diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..43988b6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.venv +venv/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +*.pdf +.pdf +plot_test/ +plot/ +performance/ +localTest/ +fig/ +figure/ +*.mp4 +*.json +Data/ControlVAE.yml +Data/Misc +Data/Pretrained +Data/utils.py +Experiment/checkpoint +Experiment/log +*.ckpt +*.STL +*.gif diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..040d7d9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "external/dlimp"] + path = external/dlimp + url = https://github.com/kvablack/dlimp diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5522eea --- /dev/null +++ b/LICENSE @@ -0,0 +1,439 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics") + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. 
+ + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. 
Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. 
For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. 
Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/README.md b/README.md new file mode 100644 index 0000000..15aa4be --- /dev/null +++ b/README.md @@ -0,0 +1,228 @@ +# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family +

+ Project Page | + Models | + Dataset +

+
+

+ 🌎English | 🇨🇳中文 +

+
+
+ UnifoLM-WMA-0 is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) Simulation Engine – operates as an interactive simulator to generate synthetic data for robot learning; (b) Policy Enhancement – connects with an action head and further optimizes decision-making performance by using the world model to predict future interactions with the environment.
+</div>
+
+## 🦾 Real-Robot Demonstrations
+| | |
+|:---:|:---:|
+| | |
+
+**Note: the top-right window shows the world model’s prediction of future action videos.**
+
+## 🔥 News
+
+* Sep 22, 2025: 🚀 We released the deployment code for running real-robot experiments with [Unitree](https://www.unitree.com/) robots.
+* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
+
+## 📑 Open-Source Plan
+- [x] Training
+- [x] Inference
+- [x] Checkpoints
+- [x] Deployment
+
+## ⚙️ Installation
+```
+conda create -n unifolm-wma python==3.10.18
+conda activate unifolm-wma
+
+conda install pinocchio=3.2.0 -c conda-forge -y
+conda install ffmpeg=7.1.1 -c conda-forge
+
+git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
+
+# If you already downloaded the repo:
+cd unifolm-world-model-action
+git submodule update --init --recursive
+
+pip install -e .
+
+cd external/dlimp
+pip install -e .
+```
+## 🧰 Model Checkpoints
+| Model | Description | Link |
+|---------|-------|------|
+|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
+|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
+
+## 🛢️ Dataset
+In our experiments, we consider the following five open-source datasets:
+| Dataset | Robot | Link |
+|---------|-------|------|
+|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
+|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
+|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
+
+To train on your own dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset’s source directory structure is as follows:
+```
+source_dir/
+  ├── dataset1_name
+  ├── dataset2_name
+  ├── dataset3_name
+  └── ...
+```
+Then, convert a dataset to the required format using the command below:
+```bash
+cd prepare_data
+python prepare_training_data.py \
+    --source_dir /path/to/your/source_dir \
+    --target_dir /path/to/save/the/converted/data \
+    --dataset_name "dataset1_name" \
+    --robot_name "a tag of the robot in the dataset" # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
+```
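+Once a dataset is converted, a quick sanity check is often worthwhile before training. The snippet below is only a minimal sketch: it assumes ```h5py``` and ```pandas``` are installed, the paths are placeholders matching the command above, and the exact keys stored in each ```.h5``` transition file depend on how ```prepare_training_data.py``` wrote them.
+```python
+# Minimal sanity check of a converted dataset (illustrative sketch only).
+import h5py
+import pandas as pd
+
+target_dir = "/path/to/save/the/converted/data"  # same as --target_dir above
+dataset_name = "dataset1_name"
+
+# Each dataset has one CSV index; its data_dir column points to the
+# main-view video of every episode (remove other views if present).
+index = pd.read_csv(f"{target_dir}/{dataset_name}.csv")
+print(index.columns.tolist())
+print(len(index), "episodes indexed")
+
+# Transitions (robot states/actions) are stored one .h5 file per episode.
+with h5py.File(f"{target_dir}/transitions/{dataset_name}/0.h5", "r") as f:
+    f.visit(print)  # list whatever groups/datasets the converter wrote
+```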
+The resulting data structure is organized as follows (note: model training only supports input from the main-view camera; if the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file):
+```
+target_dir/
+  ├── videos
+  │   ├── dataset1_name
+  │   │   ├── camera_view_dir
+  │   │   │   ├── 0.mp4
+  │   │   │   ├── 1.mp4
+  │   │   │   └── ...
+  │   └── ...
+  ├── transitions
+  │   ├── dataset1_name
+  │   │   ├── meta_data
+  │   │   ├── 0.h5
+  │   │   ├── 1.h5
+  │   │   └── ...
+  │   └── ...
+  └── dataset1_name.csv
+```
+## 🚴‍♂️ Training
+A. Our training strategy is outlined as follows:
+- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
+- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
+ +
+- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset. +
+ +
+**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
+
+B. To conduct training on a single or multiple datasets, please follow the steps below:
+- **Step 1**: The maximum number of DoFs is assumed to be 16; if your robot has more than 16 DoFs, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
+- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
+- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the " $\text{UnifoLM-WMA-0}_{Base}$ " checkpoint fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
+  ```yaml
+  model:
+    pretrained_checkpoint: /path/to/pretrained/checkpoint
+    ...
+    decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
+    ...
+  data:
+    ...
+    train:
+      ...
+      data_dir: /path/to/training/dataset/directory
+      dataset_and_weights: # List the name of each dataset below and make sure the weights sum to 1.0.
+        dataset1_name: 0.2
+        dataset2_name: 0.2
+        dataset3_name: 0.2
+        dataset4_name: 0.2
+        dataset5_name: 0.2
+  ```
+- **Step 4**: Set up the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
+- **Step 5**: Launch the training with the command:
+```
+bash scripts/train.sh
+```
+## 🌏 Inference under Interactive Simulation Mode
+To run the world model in interactive simulation mode, follow these steps:
+- **Step 1**: (Skip this step if you just want to test with the provided examples.) Prepare your own prompts following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts); a structural check of this layout is sketched after this list:
+  ```
+  world_model_interaction_prompts/
+    ├── images
+    │   ├── dataset1_name
+    │   │   ├── 0.png # Image prompt
+    │   │   └── ...
+    │   └── ...
+    ├── transitions
+    │   ├── dataset1_name
+    │   │   ├── meta_data # Used for normalization
+    │   │   ├── 0.h5 # Robot state and action data; in interaction mode,
+    │   │   │       # only used to retrieve the robot state corresponding
+    │   │   │       # to the image prompt
+    │   │   └── ...
+    │   └── ...
+    ├── dataset1_name.csv # File for loading image prompts, text instructions and corresponding robot states
+    └── ...
+  ```
+- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
+- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify the names of all datasets in ```datasets=(...)```. Then, launch the inference with the command:
+  ```
+  bash scripts/run_world_model_interaction.sh
+  ```
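+Before launching, you can verify that a custom prompt directory (from Step 1) matches the expected layout. This is a minimal illustrative sketch only; ```prompt_dir``` and ```dataset``` are placeholders you should adjust to your own data.
+```python
+# Quick structural check of a custom prompt directory (illustrative sketch).
+from pathlib import Path
+
+prompt_dir = Path("examples/world_model_interaction_prompts")
+dataset = "dataset1_name"
+
+assert (prompt_dir / "images" / dataset).is_dir(), "missing image prompts"
+assert (prompt_dir / "transitions" / dataset / "meta_data").exists(), "missing meta_data"
+assert (prompt_dir / f"{dataset}.csv").is_file(), "missing prompt index CSV"
+
+n_images = len(list((prompt_dir / "images" / dataset).glob("*.png")))
+print(f"{n_images} image prompt(s) found for {dataset}")
+```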
+
+## 🧠 Inference and Deployment under Decision-Making Mode
+
+In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds through the following steps:
+
+### Server Setup
+- **Step 1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
+- **Step 2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
+- **Step 3**: Launch the server:
+```
+conda activate unifolm-wma
+cd unifolm-world-model-action
+bash scripts/run_real_eval_server.sh
+```
+
+### Client Setup
+- **Step 1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
+- **Step 2**: Open a new terminal and establish a tunnel connection from the client to the server:
+```
+ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
+```
+- **Step 3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
+```
+cd unitree_deploy
+python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
+```
+
+## 📝 Codebase Architecture
+Here's a high-level overview of the project's code structure and core components:
+```
+unitree-world-model/
+  ├── assets          # Media assets such as GIFs, images, and demo videos
+  ├── configs         # Configuration files for training and inference
+  │   ├── inference
+  │   └── train
+  ├── examples        # Example inputs and prompts for running inference
+  ├── external        # External packages
+  ├── prepare_data    # Scripts for dataset preprocessing and format conversion
+  ├── scripts         # Main scripts for training, evaluation, and deployment
+  ├── src
+  │   ├── unitree_worldmodel   # Core Python package for the Unitree world model
+  │   │   ├── data     # Dataset loading, transformations, and dataloaders
+  │   │   ├── models   # Model architectures and backbone definitions
+  │   │   ├── modules  # Custom model modules and components
+  │   │   └── utils    # Utility functions and common helpers
+  └── unitree_deploy   # Deployment code
+```
+
+## 🙏 Acknowledgement
+Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
+
+## 📝 Citation
+```
+@misc{unifolm-wma-0,
+    author = {Unitree},
+    title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
+    year = {2025},
+}
+```
diff --git a/README_cn.md b/README_cn.md
new file mode 100644
index 0000000..143f998
--- /dev/null
+++ b/README_cn.md
@@ -0,0 +1,216 @@
+# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family

+ 项目主页 | + 开源模型 | + 开源数据 +

+
+

+ 🌎English | 🇨🇳中文 +

+
+
+**UnifoLM-WMA-0** 是宇树科技跨多类机器人本体的开源世界模型-动作架构,专为通用机器人学习而设计。其核心组件是一个可以理解机器人与环境交互物理规律的世界模型。该世界模型具备两大核心功能:(1)**仿真引擎**,作为交互式仿真器运行,为机器人学习提供合成数据;(2)**策略增强**,可与一个动作头进行对接,通过预测未来与物理世界的交互过程,进一步优化决策性能。模型的真机部署效果如下所示,其中右上角小窗口是世界模型对于未来环境变化的预测,可辅助控制指令生成。
+
+## 🦾 真机效果
+
+| | |
+|:---:|:---:|
+| | |
+
+**注:右上角小窗口显示世界模型对未来动作视频的预测。**
+
+## 🔥 新闻
+* 2025年9月22日: 🚀 我们发布了应用宇树科技机器人进行真机实验的部署代码。
+* 2025年9月15日: 🚀 我们发布了 **UnifoLM-WMA-0** 的训练与推理代码,以及对应的模型权重。
+
+
+## 📑 开源计划
+- [x] 训练代码
+- [x] 推理代码
+- [x] 模型Checkpoints
+- [x] 真机部署代码
+
+## ⚙️ 安装
+```
+conda create -n unifolm-wma python==3.10.18
+conda activate unifolm-wma
+
+conda install pinocchio=3.2.0 -c conda-forge -y
+conda install ffmpeg=7.1.1 -c conda-forge
+
+git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
+
+# If you already downloaded the repo:
+cd unifolm-world-model-action
+git submodule update --init --recursive
+
+pip install -e .
+
+cd external/dlimp
+pip install -e .
+```
+## 🧰 模型 Checkpoints
+| 模型 | 描述 | 链接 |
+|---------|-------|------|
+|$\text{UnifoLM-WMA-0}_{Base}$| 在 [Open-X](https://robotics-transformer-x.github.io/) 数据集上微调后的模型 | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
+|$\text{UnifoLM-WMA-0}_{Dual}$| 在五个[宇树科技开源数据集](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab)上,以决策与仿真双模式联合微调后的模型 | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
+
+## 🛢️ 数据集
+实验中,我们训练测试了如下五个开源数据集:
+| 数据集 | 机器人 | 链接 |
+|---------|-------|------|
+|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
+|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
+|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
+
+要在自定义数据集上训练,请首先确保数据符合 [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) 数据集格式。假设下载后的数据目录结构如下:
+```
+source_dir/
+  ├── dataset1_name
+  ├── dataset2_name
+  ├── dataset3_name
+  └── ...
+```
+随后执行以下命令进行格式转换:
+```bash
+cd prepare_data
+python prepare_training_data.py \
+    --source_dir /path/to/your/source_dir \
+    --target_dir /path/to/save/the/converted/data/directory \
+    --dataset_name "dataset1_name" \
+    --robot_name "a tag of the robot in the dataset" # 例如: Unitree Z1 Robot Arm 或 Unitree G1 Robot with Gripper。
+```
+转换后的数据结构如下(注:模型训练只支持主视角相机输入,如数据存在腕部视角,需删除CSV文件中```data_dir```列对应的视频路径):
+```
+target_dir/
+  ├── videos
+  │   ├── dataset1_name
+  │   │   ├── camera_view_dir
+  │   │   │   ├── 0.mp4
+  │   │   │   ├── 1.mp4
+  │   │   │   └── ...
+  │   └── ...
+  ├── transitions
+  │   ├── dataset1_name
+  │   │   ├── meta_data
+  │   │   ├── 0.h5
+  │   │   ├── 1.h5
+  │   │   └── ...
+  │   └── ...
+  └── dataset1_name.csv
+```
+## 🚴‍♂️ 模型训练
+一. 我们的训练策略概括如下:
+- **步骤 1**:在 [Open-X](https://robotics-transformer-x.github.io/) 数据集上微调视频生成模型,使其作为世界模型(World Model);
+- **步骤 2**:在下游任务数据集上,对 $\text{UnifoLM-WMA}$ 进行决策模式(decision-making mode)后训练;
+ +
+- **步骤 3**:在下游任务数据集上,对 $\text{UnifoLM-WMA}$ 进行仿真模式(simulation mode)后训练。 +
+ +
+**注意**:如果只需要 $\text{UnifoLM-WMA}$ 在单一模式下运行,可以跳过相应的步骤。
+
+二. 在单个或多个数据集上进行训练,请按照以下步骤操作:
+- **步骤1**:默认的最大自由度为 16 DoF,若需更多自由度,请修改 [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml) 中 ```agent_state_dim``` 及 ```agent_action_dim``` 的数值;
+- **步骤2**:在 [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json) 中为每种模态设置输入维度;
+- **步骤3**: 在 [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml) 中配置训练参数及路径。关于预训练的模型,推荐使用 $\text{UnifoLM-WMA-0}_{Base}$ ,其在 [Open-X](https://robotics-transformer-x.github.io/) 数据集上微调过;
+  ```yaml
+  model:
+    pretrained_checkpoint: /path/to/pretrained/checkpoint
+    ...
+    decision_making_only: True # 是否只训练世界模型决策模式?如果否,则决策模式与仿真模式联合训练。
+    ...
+  data:
+    ...
+    train:
+      ...
+      data_dir: /path/to/training/dataset/directory
+      dataset_and_weights: # 列出所有数据集的名称及权重,确保权重和为1.0
+        dataset1_name: 0.2
+        dataset2_name: 0.2
+        dataset3_name: 0.2
+        dataset4_name: 0.2
+        dataset5_name: 0.2
+  ```
+- **步骤4**: 在 [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh) 中配置 ```experiment_name```、```save_root``` 变量;
+- **步骤5**: 运行如下指令开启训练:
+```
+bash scripts/train.sh
+```
+## 🌏 世界模型交互推理
+要启用世界模型的交互模式,请按以下步骤操作:
+- **步骤1**:(若仅用提供的示例进行测试,可跳过此步)请按照 [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts) 目录中的格式,自定义提示词目录:
+```
+world_model_interaction_prompts/
+  ├── images
+  │   ├── dataset1_name
+  │   │   ├── 0.png # 图像提示词
+  │   │   └── ...
+  │   └── ...
+  ├── transitions
+  │   ├── dataset1_name
+  │   │   ├── meta_data # 用于归一化
+  │   │   ├── 0.h5 # 机器人状态、动作相关数据,在交互模式下仅用于获取与图像提示词对应的机器人状态
+  │   │   └── ...
+  │   └── ...
+  ├── dataset1_name.csv # 该文件用于加载对应的:图像提示词、文本指令及机器人状态
+  └── ...
+```
+- **步骤2**: 在 [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml) 中指定 ```pretrained_checkpoint```(例如:$\text{UnifoLM-WMA-0}_{Dual}$)和 ```data_dir``` 的正确路径;
+- **步骤3**: 在 [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh) 中指定```checkpoint```、```res_dir``` 和 ```prompt_dir```的正确路径,并在```datasets=(...)```中列出测试的数据集名称,然后用下述指令启动推理:
+  ```
+  bash scripts/run_world_model_interaction.sh
+  ```
+
+## 🧠 世界模型决策推理及部署
+在我们的系统中,推理在服务器端执行;机器人客户端从真实机器人收集观测信息并发送至服务器,进行视频及动作推理。可通过如下步骤实现整个过程:
+
+### 服务器端设置
+- **步骤1**: 在 [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh) 中指定 ```ckpt```、```res_dir```、```datasets```;
+- **步骤2**: 在 [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225) 中配置 ```data_dir``` 和 ```dataset_and_weights```;
+- **步骤3**: 启动服务器:
+```
+conda activate unifolm-wma
+cd unifolm-world-model-action
+bash scripts/run_real_eval_server.sh
+```
+
+### 客户端设置
+- **步骤1**: 参考 [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md),创建 ```unitree_deploy``` conda 环境,安装所需依赖包,并在真实机器人端启动控制器或服务;
+- **步骤2**: 打开一个新的终端,从客户端到服务器建立隧道连接:
+```
+ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
+```
+- **步骤3**: 运行 ```unitree_deploy/robot_client.py``` 脚本以启动推理:
+```
+cd unitree_deploy
+python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
+```
+
+## 📝 代码架构
+以下是本项目代码结构设计及核心组件说明:
+```
+unitree-world-model/
+  ├── assets          # GIF动图、静态图片和演示视频等媒体素材
+  ├── configs         # 配置文件
+  │   ├── inference
+  │   └── train
+  ├── examples        # 示例数据
+  ├── external        # 外部代码库
+  ├── prepare_data    # 数据处理
+  ├── scripts         # 主程序脚本
+  ├── src
+  │   ├── unitree_worldmodel   # 核心库
+  │   │   ├── data     # 数据加载
+  │   │   ├── models   # 模型架构
+  │   │   ├── modules  # 自定义模块
+  │   │   └── utils    # 工具函数
+  └── unitree_deploy   # 真机部署代码
+```
+
+## 🙏 致谢声明
+本项目代码基于以下优秀开源项目构建,特此致谢:[DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter)、[Diffusion Policy](https://github.com/real-stanford/diffusion_policy)、[ACT](https://github.com/MarkFzp/act-plus-plus) 和 [HPT](https://github.com/liruiw/HPT)。
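+
+## 📝 引用
+```
+@misc{unifolm-wma-0,
+    author = {Unitree},
+    title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
+    year = {2025},
+}
+```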
diff --git a/assets/pngs/dm_mode.png b/assets/pngs/dm_mode.png new file mode 100644 index 0000000..b377a3d Binary files /dev/null and b/assets/pngs/dm_mode.png differ diff --git a/assets/pngs/sim_mode.png b/assets/pngs/sim_mode.png new file mode 100644 index 0000000..6c96b09 Binary files /dev/null and b/assets/pngs/sim_mode.png differ diff --git a/ckpts/.gitattributes b/ckpts/.gitattributes new file mode 100644 index 0000000..c65e9b2 --- /dev/null +++ b/ckpts/.gitattributes @@ -0,0 +1,54 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text + +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text + +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*.tfevents* filter=lfs diff=lfs merge=lfs -text +*.db* filter=lfs diff=lfs merge=lfs -text +*.ark* filter=lfs diff=lfs merge=lfs -text +**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text +**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text +**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text + +*.gguf* filter=lfs diff=lfs merge=lfs -text +*.ggml filter=lfs diff=lfs merge=lfs -text +*.llamafile* filter=lfs diff=lfs merge=lfs -text +*.pt2 filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text + +assets/real_cleanup_pencils.gif filter=lfs diff=lfs merge=lfs -text +assets/world_model_interaction.gif filter=lfs diff=lfs merge=lfs -text +assets/real_dual_stackbox.gif filter=lfs diff=lfs merge=lfs -text +assets/real_g1_pack_camera.gif filter=lfs diff=lfs merge=lfs -text +assets/real_z1_stackbox.gif filter=lfs diff=lfs merge=lfs -text +unifolm_wma_dual.ckpt filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/ckpts/LICENSE b/ckpts/LICENSE new file mode 100644 index 0000000..5522eea --- /dev/null +++ b/ckpts/LICENSE @@ -0,0 +1,439 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics") + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. 
Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. 
Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. 
Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. 
You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. 
For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/ckpts/README.md b/ckpts/README.md new file mode 100644 index 0000000..038ff9b --- /dev/null +++ b/ckpts/README.md @@ -0,0 +1,38 @@ +--- +tags: +- robotics +--- + +# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family +

+ Project Page | + Code | + Dataset +

+
+
+ UnifoLM-WMA-0 is Unitree's first open-source world-model–action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) Simulation Engine – operates as an interactive simulator to generate synthetic data for robot learning; (b) Policy Enhancement – connects to an action head and further optimizes decision-making performance by predicting future interactions with the world. +
+
+ +## 🦾 Real Robot Deployment +| | | +|:---:|:---:| +| | | + +**Note: the top-right window shows the world model’s prediction of future environmental changes.** + +## License +The model is released under the CC BY-NC-SA 4.0 license as found in the [LICENSE](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0/blob/main/LICENSE). You are responsible for ensuring that your use of Unitree AI Models complies with all applicable laws. + +## Model Architecture +![Demo](assets/world_model_interaction.gif) + +## Citation +``` +@misc{unifolm-wma-0, + author = {Unitree}, + title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family}, + year = {2025}, +} +``` \ No newline at end of file diff --git a/configs/inference/base_model_inference.yaml b/configs/inference/base_model_inference.yaml new file mode 100644 index 0000000..157000a --- /dev/null +++ b/configs/inference/base_model_inference.yaml @@ -0,0 +1,213 @@ +model: + target: unifolm_wma.models.ddpms.LatentVisualDiffusion + params: + rescale_betas_zero_snr: True + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + timesteps: 1000 + first_stage_key: video + cond_stage_key: instruction + cond_stage_trainable: False + conditioning_key: hybrid + image_size: [40, 64] + channels: 4 + scale_by_std: False + scale_factor: 0.18215 + use_ema: False + uncond_type: 'empty_seq' + use_dynamic_rescale: true + base_scale: 0.7 + fps_condition_type: 'fps' + perframe_ae: True + freeze_embedder: True + n_obs_steps_imagen: 1 + n_obs_steps_acting: 1 + agent_state_dim: 16 + agent_action_dim: 16 + + ###################### DP Related + input_pertub: 0.1 + lr_scheduler: cosine + lr_warmup_steps: 2000 + num_epochs: 30000 + gradient_accumulate_every: 1 + use_scheduler: True + dp_use_ema: True + + dp_ema_config: + target: unifolm_wma.models.diffusion_head.ema_model.EMAModel + params: + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 + + noise_scheduler_config: + target: diffusers.DDIMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.0001 + beta_end: 0.02 + beta_schedule: squaredcos_cap_v2 + clip_sample: True + set_alpha_to_one: True + steps_offset: 0 + prediction_type: epsilon + + dp_optimizer_config: + target: torch.optim.AdamW + params: + lr: 1.0e-4 + betas: [0.95, 0.999] + eps: 1.0e-8 + weight_decay: 1.0e-6 + + wma_config: + target: unifolm_wma.modules.networks.wma_model.WMAModel + params: + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + dropout: 0.1 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: True + temporal_conv: True + temporal_attention: True + temporal_selfatt_only: True + use_relative_position: False + use_causal_attention: False + temporal_length: 16 + addition_attention: True + image_cross_attention: True + default_fs: 10 + fs_condition: True + cross_attention_scale_learnable: False + n_obs_steps: ${model.params.n_obs_steps_imagen} + num_stem_token: 16 + base_model_gen_only: True + + unet_head_config: + target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D + params: + input_dim: ${model.params.agent_action_dim} + n_obs_steps: ${model.params.n_obs_steps_acting} + diffusion_step_embed_dim: 128 + down_dims: [256, 512, 1024, 2048] + kernel_size: 5 + n_groups: 8 + cond_predict_scale: True + num_head_channels: ${model.params.wma_config.params.num_head_channels} 
+            horizon: ${model.params.wma_config.params.temporal_length}
+            use_linear_attn: ${model.params.wma_config.params.use_linear}
+            use_linear_act_proj: True
+            act_proj_dim: 32
+            cond_cross_attention: False
+            context_dims: []
+            image_size: ${model.params.image_size}
+            imagen_cond_gradient: True
+            last_frame_only: False
+            use_imagen_mid_only: False
+            use_z_only: False
+
+        obs_encoder_config:
+          target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+          params:
+            rgb_model_config:
+              target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+              params:
+                name: resnet18
+                weights: null
+            resize_shape: null
+            crop_shape: null
+            random_crop: False
+            use_group_norm: True
+            share_rgb_model: False
+            imagenet_norm: True
+            use_spatial_softmax: True
+            spatial_softmax_kp: 128
+
+    ###################### Action Tokenization
+    stem_process_config:
+      target: unifolm_wma.modules.encoders.condition.SATokenProjector
+      params:
+        dim: 1024
+        depth: 1
+        dim_head: 64
+        heads: 16
+        num_queries: ${model.params.wma_config.params.num_stem_token}
+        output_dim: 1024
+        ff_mult: 4
+        chunk_size: ${model.params.wma_config.params.temporal_length}
+
+    first_stage_config:
+      target: unifolm_wma.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: True
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+            - 1
+            - 2
+            - 4
+            - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+
+    img_cond_stage_config:
+      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+      params:
+        freeze: true
+
+    image_proj_stage_config:
+      target: unifolm_wma.modules.encoders.resampler.Resampler
+      params:
+        dim: 1024
+        depth: 4
+        dim_head: 64
+        heads: 12
+        num_queries: 16
+        embedding_dim: 1280
+        output_dim: 1024
+        ff_mult: 4
+        video_length: ${model.params.wma_config.params.temporal_length}
+
+    normalization_config:
+      input_shapes:
+        observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      input_normalization_modes:
+        observation.state: 'min_max'
+      output_shapes:
+        action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      output_normalization_modes:
+        action: 'min_max'
diff --git a/configs/inference/world_model_decision_making.yaml b/configs/inference/world_model_decision_making.yaml
new file mode 100644
index 0000000..2f51dce
--- /dev/null
+++ b/configs/inference/world_model_decision_making.yaml
@@ -0,0 +1,240 @@
+model:
+  target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+  params:
+    rescale_betas_zero_snr: True
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.012
+    num_timesteps_cond: 1
+    timesteps: 1000
+    first_stage_key: video
+    cond_stage_key: instruction
+    cond_stage_trainable: False
+    conditioning_key: hybrid
+    image_size: [40, 64]
+    channels: 4
+    scale_by_std: False
+    scale_factor: 0.18215
+    use_ema: False
+    uncond_type: 'empty_seq'
+    use_dynamic_rescale: true
+    base_scale: 0.7
+    fps_condition_type: 'fps'
+    perframe_ae: True
+    freeze_embedder: True
+    n_obs_steps_imagen: 2
+    n_obs_steps_acting: 2
+    agent_state_dim: 16
+    agent_action_dim: 16
+    decision_making_only: True
+
+    ###################### DP Related
+    input_pertub: 0.1
+    lr_scheduler: cosine
+    lr_warmup_steps: 2000
+    num_epochs: 30000
+    gradient_accumulate_every: 1
+
use_scheduler: True + dp_use_ema: True + + dp_ema_config: + target: unifolm_wma.models.diffusion_head.ema_model.EMAModel + params: + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 + + noise_scheduler_config: + target: diffusers.DDIMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.0001 + beta_end: 0.02 + beta_schedule: squaredcos_cap_v2 + clip_sample: True + set_alpha_to_one: True + steps_offset: 0 + prediction_type: epsilon + + dp_optimizer_config: + target: torch.optim.AdamW + params: + lr: 1.0e-4 + betas: [0.95, 0.999] + eps: 1.0e-8 + weight_decay: 1.0e-6 + + wma_config: + target: unifolm_wma.modules.networks.wma_model.WMAModel + params: + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + dropout: 0.1 + num_head_channels: 64 + transformer_depth: 1 + context_dim: 1024 + use_linear: true + use_checkpoint: True + temporal_conv: True + temporal_attention: True + temporal_selfatt_only: True + use_relative_position: False + use_causal_attention: False + temporal_length: 16 + addition_attention: True + image_cross_attention: True + default_fs: 10 + fs_condition: True + cross_attention_scale_learnable: False + n_obs_steps: ${model.params.n_obs_steps_imagen} + num_stem_token: 16 + base_model_gen_only: False + + unet_head_config: + target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D + params: + input_dim: ${model.params.agent_action_dim} + n_obs_steps: ${model.params.n_obs_steps_acting} + diffusion_step_embed_dim: 128 + down_dims: [256, 512, 1024, 2048] + kernel_size: 5 + n_groups: 8 + cond_predict_scale: True + num_head_channels: ${model.params.wma_config.params.num_head_channels} + horizon: ${model.params.wma_config.params.temporal_length} + use_linear_attn: ${model.params.wma_config.params.use_linear} + use_linear_act_proj: True + act_proj_dim: 32 + cond_cross_attention: False + context_dims: [] + image_size: ${model.params.image_size} + imagen_cond_gradient: True + last_frame_only: False + use_imagen_mid_only: False + use_z_only: False + + obs_encoder_config: + target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder + params: + rgb_model_config: + target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet + params: + name: resnet18 + weights: null + resize_shape: null + crop_shape: null + random_crop: False + use_group_norm: True + share_rgb_model: False + imagenet_norm: True + use_spatial_softmax: True + spatial_softmax_kp: 128 + + ###################### Action Tokenization + stem_process_config: + target: unifolm_wma.modules.encoders.condition.SATokenProjector + params: + dim: 1024 + depth: 1 + dim_head: 64 + heads: 16 + num_queries: ${model.params.wma_config.params.num_stem_token} + output_dim: 1024 + ff_mult: 4 + chunk_size: ${model.params.wma_config.params.temporal_length} + + first_stage_config: + target: unifolm_wma.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + img_cond_stage_config: + target: 
unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+      params:
+        freeze: true
+
+    image_proj_stage_config:
+      target: unifolm_wma.modules.encoders.resampler.Resampler
+      params:
+        dim: 1024
+        depth: 4
+        dim_head: 64
+        heads: 12
+        num_queries: 16
+        embedding_dim: 1280
+        output_dim: 1024
+        ff_mult: 4
+        video_length: ${model.params.wma_config.params.temporal_length}
+
+    normalization_config:
+      input_shapes:
+        observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      input_normalization_modes:
+        observation.state: 'min_max'
+      output_shapes:
+        action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      output_normalization_modes:
+        action: 'min_max'
+
+data:
+  target: unifolm_wma.utils.data.DataModuleFromConfig
+  params:
+    batch_size: 6
+    num_workers: 12
+    wrap: False
+    test:
+      target: unifolm_wma.data.wma_data.WMAData
+      params:
+        data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
+        video_length: ${model.params.wma_config.params.temporal_length}
+        frame_stride: 2
+        load_raw_resolution: True
+        resolution: [320, 512]
+        spatial_transform: resize_center_crop
+        crop_resolution: [320, 512]
+        random_fs: False
+        cond_robot_label_prob: 0.0
+        normalization_mode: 'min_max'
+        individual_normalization: True
+        n_obs_steps: ${model.params.n_obs_steps_imagen}
+        max_action_dim: ${model.params.agent_action_dim}
+        max_state_dim: ${model.params.agent_state_dim}
+        dataset_and_weights:
+          unitree_g1_pack_camera: 1.0
diff --git a/configs/inference/world_model_interaction.yaml b/configs/inference/world_model_interaction.yaml
new file mode 100644
index 0000000..2d69e09
--- /dev/null
+++ b/configs/inference/world_model_interaction.yaml
@@ -0,0 +1,244 @@
+model:
+  target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+  params:
+    rescale_betas_zero_snr: True
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.012
+    num_timesteps_cond: 1
+    timesteps: 1000
+    first_stage_key: video
+    cond_stage_key: instruction
+    cond_stage_trainable: False
+    conditioning_key: hybrid
+    image_size: [40, 64]
+    channels: 4
+    scale_by_std: False
+    scale_factor: 0.18215
+    use_ema: False
+    uncond_type: 'empty_seq'
+    use_dynamic_rescale: true
+    base_scale: 0.7
+    fps_condition_type: 'fps'
+    perframe_ae: True
+    freeze_embedder: True
+    n_obs_steps_imagen: 2
+    n_obs_steps_acting: 2
+    agent_state_dim: 16
+    agent_action_dim: 16
+    decision_making_only: False
+
+    ###################### DP Related
+    input_pertub: 0.1
+    lr_scheduler: cosine
+    lr_warmup_steps: 2000
+    num_epochs: 30000
+    gradient_accumulate_every: 1
+    use_scheduler: True
+    dp_use_ema: True
+
+    dp_ema_config:
+      target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+      params:
+        update_after_step: 0
+        inv_gamma: 1.0
+        power: 0.75
+        min_value: 0.0
+        max_value: 0.9999
+
+    noise_scheduler_config:
+      target: diffusers.DDIMScheduler
+      params:
+        num_train_timesteps: 1000
+        beta_start: 0.0001
+        beta_end: 0.02
+        beta_schedule: squaredcos_cap_v2
+        clip_sample: True
+        set_alpha_to_one: True
+        steps_offset: 0
+        prediction_type: epsilon
+
+    dp_optimizer_config:
+      target: torch.optim.AdamW
+      params:
+        lr: 1.0e-4
+        betas: [0.95, 0.999]
+        eps: 1.0e-8
+        weight_decay: 1.0e-6
+
+    wma_config:
+      target: unifolm_wma.modules.networks.wma_model.WMAModel
+      params:
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions:
+          - 4
+          - 2
+          - 1
+        num_res_blocks: 2
+        channel_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+        dropout: 0.1
+        num_head_channels: 64
+        transformer_depth: 1
+        context_dim: 1024
+        use_linear: true
+        use_checkpoint: True
+        temporal_conv: True
+        temporal_attention: True
+        temporal_selfatt_only: True
+        use_relative_position: False
+        use_causal_attention: False
+        temporal_length: 16
+        addition_attention: True
+        image_cross_attention: True
+        default_fs: 10
+        fs_condition: True
+        cross_attention_scale_learnable: False
+        n_obs_steps: ${model.params.n_obs_steps_imagen}
+        num_stem_token: 16
+        base_model_gen_only: False
+
+        unet_head_config:
+          target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+          params:
+            input_dim: ${model.params.agent_action_dim}
+            n_obs_steps: ${model.params.n_obs_steps_acting}
+            diffusion_step_embed_dim: 128
+            down_dims: [256, 512, 1024, 2048]
+            kernel_size: 5
+            n_groups: 8
+            cond_predict_scale: True
+            num_head_channels: ${model.params.wma_config.params.num_head_channels}
+            horizon: ${model.params.wma_config.params.temporal_length}
+            use_linear_attn: ${model.params.wma_config.params.use_linear}
+            use_linear_act_proj: True
+            act_proj_dim: 32
+            cond_cross_attention: False
+            context_dims: []
+            image_size: ${model.params.image_size}
+            imagen_cond_gradient: True
+            last_frame_only: False
+            use_imagen_mid_only: False
+            use_z_only: False
+
+        obs_encoder_config:
+          target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+          params:
+            rgb_model_config:
+              target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+              params:
+                name: resnet18
+                weights: null
+            resize_shape: null
+            crop_shape: null
+            random_crop: False
+            use_group_norm: True
+            share_rgb_model: False
+            imagenet_norm: True
+            use_spatial_softmax: True
+            spatial_softmax_kp: 128
+
+    ###################### Action Tokenization
+    stem_process_config:
+      target: unifolm_wma.modules.encoders.condition.SATokenProjector
+      params:
+        dim: 1024
+        depth: 1
+        dim_head: 64
+        heads: 16
+        num_queries: ${model.params.wma_config.params.num_stem_token}
+        output_dim: 1024
+        ff_mult: 4
+        chunk_size: ${model.params.wma_config.params.temporal_length}
+
+    first_stage_config:
+      target: unifolm_wma.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: True
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+            - 1
+            - 2
+            - 4
+            - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+
+    img_cond_stage_config:
+      target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+      params:
+        freeze: true
+
+    image_proj_stage_config:
+      target: unifolm_wma.modules.encoders.resampler.Resampler
+      params:
+        dim: 1024
+        depth: 4
+        dim_head: 64
+        heads: 12
+        num_queries: 16
+        embedding_dim: 1280
+        output_dim: 1024
+        ff_mult: 4
+        video_length: ${model.params.wma_config.params.temporal_length}
+
+    normalization_config:
+      input_shapes:
+        observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      input_normalization_modes:
+        observation.state: 'min_max'
+      output_shapes:
+        action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+      output_normalization_modes:
+        action: 'min_max'
+
+data:
+  target: unifolm_wma.utils.data.DataModuleFromConfig
+  params:
+    batch_size: 6
+    num_workers: 12
+    wrap: False
+    test:
+      target: unifolm_wma.data.wma_data.WMAData
+      params:
+        data_dir: '/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts'
+        video_length: ${model.params.wma_config.params.temporal_length}
+        frame_stride: 2
+        load_raw_resolution: True
+        resolution: [320, 512]
+        spatial_transform: resize_center_crop
+        crop_resolution: [320, 512]
+        random_fs: False
+        cond_robot_label_prob: 0.0
+        normalization_mode: 'min_max'
+        individual_normalization: True
+        n_obs_steps: ${model.params.n_obs_steps_imagen}
+        max_action_dim: ${model.params.agent_action_dim}
+        max_state_dim: ${model.params.agent_state_dim}
+        dataset_and_weights:
+          unitree_z1_stackbox: 0.2
+          unitree_z1_dual_arm_stackbox: 0.2
+          unitree_z1_dual_arm_stackbox_v2: 0.2
+          unitree_z1_dual_arm_cleanup_pencils: 0.2
+          unitree_g1_pack_camera: 0.2
diff --git a/configs/train/config.yaml b/configs/train/config.yaml
new file mode 100644
index 0000000..9e7bb06
--- /dev/null
+++ b/configs/train/config.yaml
@@ -0,0 +1,287 @@
+model:
+  pretrained_checkpoint: /path/to/pretrained/checkpoint
+  base_learning_rate: 1.0e-05
+  scale_lr: False
+  target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+  params:
+    rescale_betas_zero_snr: True
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.012
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: video
+    cond_stage_key: instruction
+    cond_stage_trainable: False
+    image_proj_model_trainable: True
+    conditioning_key: hybrid
+    image_size: [40, 64]
+    channels: 4
+    scale_by_std: False
+    scale_factor: 0.18215
+    use_ema: False
+    uncond_prob: 0.05
+    uncond_type: 'empty_seq'
+    rand_cond_frame: false
+    use_dynamic_rescale: true
+    base_scale: 0.7
+    fps_condition_type: 'fps'
+    perframe_ae: True
+    freeze_embedder: True
+    n_obs_steps_imagen: 2
+    n_obs_steps_acting: 2
+    agent_state_dim: 16
+    agent_action_dim: 16
+    decision_making_only: True
+
+    ###################### DP Related
+    input_pertub: 0.1
+    lr_scheduler: cosine
+    lr_warmup_steps: 2000
+    num_epochs: 60000
+    gradient_accumulate_every: 1
+    use_scheduler: True
+    dp_use_ema: True
+
+    dp_ema_config:
+      target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+      params:
+        update_after_step: 0
+        inv_gamma: 1.0
+        power: 0.75
+        min_value: 0.0
+        max_value: 0.9999
+
+    noise_scheduler_config:
+      target: diffusers.DDIMScheduler
+      params:
+        num_train_timesteps: 1000
+        beta_start: 0.0001
+        beta_end: 0.02
+        beta_schedule: squaredcos_cap_v2
+        clip_sample: True
+        set_alpha_to_one: True
+        steps_offset: 0
+        prediction_type: epsilon
+
+    dp_optimizer_config:
+      target: torch.optim.AdamW
+      params:
+        lr: 1.0e-4
+        betas: [0.95, 0.999]
+        eps: 1.0e-8
+        weight_decay: 1.0e-6
+
+    wma_config:
+      target: unifolm_wma.modules.networks.wma_model.WMAModel
+      params:
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions:
+          - 4
+          - 2
+          - 1
+        num_res_blocks: 2
+        channel_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+        dropout: 0.1
+        num_head_channels: 64
+        transformer_depth: 1
+        context_dim: 1024
+        use_linear: true
+        use_checkpoint: True
+        temporal_conv: True
+        temporal_attention: True
+        temporal_selfatt_only: True
+        use_relative_position: False
+        use_causal_attention: False
+        temporal_length: 16
+        addition_attention: True
+        image_cross_attention: True
+        default_fs: 10
+        fs_condition: True
+        cross_attention_scale_learnable: False
+        n_obs_steps: ${model.params.n_obs_steps_imagen}
+        num_stem_token: 16
+        base_model_gen_only: False
+
+        unet_head_config:
+          target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+          params:
+            input_dim:
${model.params.agent_action_dim} + n_obs_steps: ${model.params.n_obs_steps_acting} + diffusion_step_embed_dim: 128 + down_dims: [256, 512, 1024, 2048] + kernel_size: 5 + n_groups: 8 + cond_predict_scale: True + num_head_channels: ${model.params.wma_config.params.num_head_channels} + horizon: ${model.params.wma_config.params.temporal_length} + use_linear_attn: ${model.params.wma_config.params.use_linear} + use_linear_act_proj: True + act_proj_dim: 32 + cond_cross_attention: False + context_dims: [] + image_size: ${model.params.image_size} + imagen_cond_gradient: True + last_frame_only: False + use_imagen_mid_only: False + use_z_only: False + + obs_encoder_config: + target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder + params: + rgb_model_config: + target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet + params: + name: resnet18 + weights: null + resize_shape: null + crop_shape: null + random_crop: False + use_group_norm: True + share_rgb_model: False + imagenet_norm: True + use_spatial_softmax: True + spatial_softmax_kp: 128 + + ###################### Action Tokenization + stem_process_config: + target: unifolm_wma.modules.encoders.condition.SATokenProjector + params: + dim: 1024 + depth: 1 + dim_head: 64 + heads: 16 + num_queries: ${model.params.wma_config.params.num_stem_token} + output_dim: 1024 + ff_mult: 4 + chunk_size: ${model.params.wma_config.params.temporal_length} + + first_stage_config: + target: unifolm_wma.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + img_cond_stage_config: + target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 + params: + freeze: true + + image_proj_stage_config: + target: unifolm_wma.modules.encoders.resampler.Resampler + params: + dim: 1024 + depth: 4 + dim_head: 64 + heads: 12 + num_queries: 16 + embedding_dim: 1280 + output_dim: 1024 + ff_mult: 4 + video_length: ${model.params.wma_config.params.temporal_length} + + normalization_config: + input_shapes: + observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim} + input_normalization_modes: + observation.state: 'min_max' + output_shapes: + action: ${model.params.wma_config.params.unet_head_config.params.input_dim} + output_normalization_modes: + action: 'min_max' + +data: + target: unifolm_wma.utils.data.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 12 + wrap: False + train: + target: unifolm_wma.data.wma_data.WMAData + params: + data_dir: '/path/to/training/dataset/directory' + video_length: ${model.params.wma_config.params.temporal_length} + frame_stride: 2 + load_raw_resolution: True + resolution: [320, 512] + spatial_transform: resize_center_crop + crop_resolution: [320, 512] + random_fs: False + cond_robot_label_prob: 0.0 + normalization_mode: 'min_max' + individual_normalization: True + n_obs_steps: ${model.params.n_obs_steps_imagen} + max_action_dim: ${model.params.agent_action_dim} + max_state_dim: ${model.params.agent_state_dim} + dataset_and_weights: + unitree_z1_stackbox: 0.2 + unitree_z1_dual_arm_stackbox: 0.2 + unitree_z1_dual_arm_stackbox_v2: 0.2 + 
unitree_z1_dual_arm_cleanup_pencils: 0.2 + unitree_g1_pack_camera: 0.2 + +lightning: + precision: 16 + trainer: + benchmark: True + accumulate_grad_batches: 2 + max_steps: 300000 + log_every_n_steps: 50 + val_check_interval: 1.0 + gradient_clip_algorithm: 'norm' + gradient_clip_val: 0.5 + enable_model_summary: False + callbacks: + model_checkpoint: + target: pytorch_lightning.callbacks.ModelCheckpoint + params: + every_n_train_steps: 1000 + filename: "{epoch}-{step}" + save_weights_only: True + metrics_over_trainsteps_checkpoint: + target: pytorch_lightning.callbacks.ModelCheckpoint + params: + filename: '{epoch}-{step}' + save_weights_only: True + every_n_train_steps: 10000 + batch_logger: + target: unifolm_wma.utils.callbacks.ImageLogger + params: + batch_frequency: 20000 + to_local: False + max_images: 8 + log_images_kwargs: + ddim_steps: 16 + unconditional_guidance_scale: 1.0 + timestep_spacing: uniform_trailing + guidance_rescale: 0.7 diff --git a/examples/base_model_prompts/0.png b/examples/base_model_prompts/0.png new file mode 100644 index 0000000..468c434 Binary files /dev/null and b/examples/base_model_prompts/0.png differ diff --git a/examples/base_model_prompts/1.png b/examples/base_model_prompts/1.png new file mode 100644 index 0000000..899edfb Binary files /dev/null and b/examples/base_model_prompts/1.png differ diff --git a/examples/base_model_prompts/10.png b/examples/base_model_prompts/10.png new file mode 100644 index 0000000..e195de4 Binary files /dev/null and b/examples/base_model_prompts/10.png differ diff --git a/examples/base_model_prompts/11.png b/examples/base_model_prompts/11.png new file mode 100644 index 0000000..fca3c47 Binary files /dev/null and b/examples/base_model_prompts/11.png differ diff --git a/examples/base_model_prompts/12.png b/examples/base_model_prompts/12.png new file mode 100644 index 0000000..53dda0e Binary files /dev/null and b/examples/base_model_prompts/12.png differ diff --git a/examples/base_model_prompts/13.png b/examples/base_model_prompts/13.png new file mode 100644 index 0000000..cbf5623 Binary files /dev/null and b/examples/base_model_prompts/13.png differ diff --git a/examples/base_model_prompts/14.png b/examples/base_model_prompts/14.png new file mode 100644 index 0000000..317843c Binary files /dev/null and b/examples/base_model_prompts/14.png differ diff --git a/examples/base_model_prompts/15.png b/examples/base_model_prompts/15.png new file mode 100644 index 0000000..3a1a4b2 Binary files /dev/null and b/examples/base_model_prompts/15.png differ diff --git a/examples/base_model_prompts/2.png b/examples/base_model_prompts/2.png new file mode 100644 index 0000000..b6e58fb Binary files /dev/null and b/examples/base_model_prompts/2.png differ diff --git a/examples/base_model_prompts/3.png b/examples/base_model_prompts/3.png new file mode 100644 index 0000000..39611c4 Binary files /dev/null and b/examples/base_model_prompts/3.png differ diff --git a/examples/base_model_prompts/4.png b/examples/base_model_prompts/4.png new file mode 100644 index 0000000..8ec01a0 Binary files /dev/null and b/examples/base_model_prompts/4.png differ diff --git a/examples/base_model_prompts/5.png b/examples/base_model_prompts/5.png new file mode 100644 index 0000000..53a5a49 Binary files /dev/null and b/examples/base_model_prompts/5.png differ diff --git a/examples/base_model_prompts/6.png b/examples/base_model_prompts/6.png new file mode 100644 index 0000000..b3a6312 Binary files /dev/null and b/examples/base_model_prompts/6.png differ diff --git 
a/examples/base_model_prompts/7.png b/examples/base_model_prompts/7.png
new file mode 100644
index 0000000..f4366f8
Binary files /dev/null and b/examples/base_model_prompts/7.png differ
diff --git a/examples/base_model_prompts/8.png b/examples/base_model_prompts/8.png
new file mode 100644
index 0000000..efaaa93
Binary files /dev/null and b/examples/base_model_prompts/8.png differ
diff --git a/examples/base_model_prompts/9.png b/examples/base_model_prompts/9.png
new file mode 100644
index 0000000..c86cda7
Binary files /dev/null and b/examples/base_model_prompts/9.png differ
diff --git a/examples/base_model_prompts/prompts.csv b/examples/base_model_prompts/prompts.csv
new file mode 100644
index 0000000..d6df917
--- /dev/null
+++ b/examples/base_model_prompts/prompts.csv
@@ -0,0 +1,17 @@
+videoid,instruction,fps,start_idx,fs,num_gen
+0,wash the pan,16.0,152.0,2.0,6.0
+1,pick up the blue cup and put it into the brown cup. ,5.0,0.0,2.0,4.0
+2,close top drawer,3.0,0.0,2.0,1.0
+3,Close the laptop.,10.0,0.0,2.0,4.0
+4,destack cube,5.0,0.0,2.0,3.0
+5,arrange plate and fork,20.0,40.0,2.0,4.0
+6,Place the lid on the teapot,15.0,30.0,1.0,4.0
+7,Pick up the green object and insert it.,10.0,0.0,2.0,4.0
+8,place the burger meat in the oven,10.0,0.0,2.0,2.0
+9,make a cup of coffee with the keurig machine,10.0,0.0,2.0,4.0
+10,assemble one_leg,10.0,0.0,2.0,7.0
+11,get the cloth and wipe up the spill under the wine glass,8.0,669.0,2.0,3.0
+12,place dishes in the dish rack,10.0,0.0,2.0,4.0
+13,move redbull can near green can,3.0,3.0,2.0,1.0
+14,open the drawer,5.0,5.0,1.0,2.0
+15,sweep the green cloth to the left side of the table,5.0,0.0,2.0,3.0
diff --git a/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png b/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png
new file mode 100644
index 0000000..b295e63
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png
new file mode 100644
index 0000000..d919c6f
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png
new file mode 100644
index 0000000..ebcc502
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png
new file mode 100644
index 0000000..17008a2
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png
new file mode 100644
index 0000000..8ec19a0
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5
new file mode
100644 index 0000000..3505d37 Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5 differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors new file mode 100644 index 0000000..3dafbcb Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5 new file mode 100644 index 0000000..7835773 Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5 differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors new file mode 100644 index 0000000..e3194ab Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5 new file mode 100755 index 0000000..28184eb Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5 differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors new file mode 100755 index 0000000..2bfc77f Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors new file mode 100755 index 0000000..fa7fd40 Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5 new file mode 100644 index 0000000..292bcfa Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5 differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors new file mode 100644 index 0000000..6ef7a6c Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5 
b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5 new file mode 100755 index 0000000..9cdae1b Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5 differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors new file mode 100644 index 0000000..62bde44 Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors differ diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors new file mode 100644 index 0000000..1918ea0 Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors differ diff --git a/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv b/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv new file mode 100644 index 0000000..0883835 --- /dev/null +++ b/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv @@ -0,0 +1,2 @@ +videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps +0,x,x,unitree_g1_pack_camera,Pack black camera into box.,x,x,x,Unitree G1 Robot with Gripper,30 diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv new file mode 100644 index 0000000..ca39924 --- /dev/null +++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv @@ -0,0 +1,2 @@ +videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps +0,x,x,unitree_z1_dual_arm_cleanup_pencils,clean up eraser and pencils,x,x,x,Unitree Z1 Robot Dual-Arm,30 diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv new file mode 100644 index 0000000..08b30a5 --- /dev/null +++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv @@ -0,0 +1,2 @@ +videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps +0,x,x,unitree_z1_dual_arm_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Dual-Arm,30 diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv new file mode 100644 index 0000000..2581ded --- /dev/null +++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv @@ -0,0 +1,2 @@ +videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps +0,x,x,unitree_z1_dual_arm_stackbox_v2,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top",x,x,x,Unitree Z1 Robot Dual-Arm,30 diff --git a/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv b/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv new file mode 
100755
index 0000000..a4505c1
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_z1_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Arm,30
diff --git a/external/dlimp b/external/dlimp
new file mode 160000
index 0000000..5edaa46
--- /dev/null
+++ b/external/dlimp
@@ -0,0 +1 @@
+Subproject commit 5edaa4691567873d495633f2708982b42edf1972
diff --git a/model_architecture_analysis.md b/model_architecture_analysis.md
new file mode 100644
index 0000000..4114c2c
--- /dev/null
+++ b/model_architecture_analysis.md
@@ -0,0 +1,1204 @@
+# UnifoLM World Model Action - Detailed Model Architecture Analysis
+
+## Table of Contents
+1. [Architecture Overview](#1-architecture-overview)
+2. [Inference Pipeline Analysis](#2-inference-pipeline-analysis)
+3. [Core Components in Detail](#3-core-components-in-detail)
+4. [Performance Bottleneck Analysis](#4-performance-bottleneck-analysis)
+5. [Kernel Fusion Optimization Recommendations](#5-kernel-fusion-optimization-recommendations)
+
+---
+
+## 1. Architecture Overview
+
+### 1.1 Model Hierarchy
+
+```
+DDPM (top-level model)
+├── DiffusionWrapper (conditioning wrapper)
+│   └── UNet3D (core diffusion model)
+│       ├── Time Embedding
+│       ├── Downsampling Blocks
+│       ├── Middle Blocks
+│       └── Upsampling Blocks
+├── VAE (variational autoencoder)
+│   ├── Encoder
+│   └── Decoder
+├── CLIP Image Encoder
+├── Text Encoder
+├── State Projector
+└── Action Projector
+```
+
+### 1.2 Inference Data Flow
+
+```
+Input observation
+    ↓
+[1] Condition encoding stage
+    ├── Image  → CLIP Encoder     → Image Embedding
+    ├── Image  → VAE Encoder      → Latent Condition
+    ├── Text   → Text Encoder     → Text Embedding
+    ├── State  → State Projector  → State Embedding
+    └── Action → Action Projector → Action Embedding
+    ↓
+[2] DDIM sampling stage (n iterations)
+    ├── Initialize noise x_T
+    └── For step in [0, n]:
+        ├── Model forward pass (UNet3D)
+        │   ├── Timestep embedding
+        │   ├── Condition injection (CrossAttention)
+        │   └── Noise prediction
+        ├── DDIM update rule
+        └── x_{t-1} = f(x_t, noise_pred)
+    ↓
+[3] VAE decoding stage
+    └── Latent → VAE Decoder → video frames
+```
+
+---
+
+## 2. Inference Pipeline Analysis
+
+### 2.1 Stage 1: Action Generation (sim_mode=False)
+
+**Purpose**: generate an action sequence from the observations and the instruction.
+
+**Inputs**:
+- `observation.images.top`: past image observations [B, C, T_obs, H, W]
+- `observation.state`: past states [B, T_obs, state_dim]
+- `action`: past actions [B, T_action, action_dim]
+- `instruction`: text instruction
+
+**Outputs**:
+- `pred_videos`: predicted video [B, C, T, H, W]
+- `pred_actions`: predicted action sequence [B, T, action_dim]
+
+**Key characteristics**:
+- The action condition is zeroed out (`cond_action_emb = torch.zeros_like(...)`)
+- The text instruction serves as the primary guidance
+
+### 2.2 Stage 2: World-Model Interaction (sim_mode=True)
+
+**Purpose**: predict future observations from actions.
+
+**Inputs**:
+- `observation.images.top`: the current image
+- `observation.state`: the current state
+- `action`: the action sequence generated in Stage 1
+
+**Outputs**:
+- `pred_videos`: predicted future video frames
+- `pred_states`: predicted future states
+
+**Key characteristics**:
+- No text instruction is used (`text_input=False`)
+- The action condition is actually consumed
+
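+To make the interplay between the two stages concrete, here is a minimal, self-contained sketch of the closed loop. The `model_fn` stub and its signature are hypothetical stand-ins for illustration only; the real entry point is `image_guided_synthesis_sim_mode` in `scripts/evaluation/world_model_interaction.py`:
+
+```python
+import torch
+
+def model_fn(obs, instruction=None, actions=None, sim_mode=False):
+    """Hypothetical stand-in for one full DDIM-sampled model call."""
+    B, T, dim = 1, 16, 16
+    video = torch.zeros(B, 3, T, 320, 512)    # predicted frames
+    if not sim_mode:
+        # Stage 1: action conditioning zeroed; the instruction drives guidance.
+        return video, torch.zeros(B, T, dim)   # pred_videos, pred_actions
+    # Stage 2: no text input; the supplied actions condition the rollout.
+    return video, torch.zeros(B, T, dim)       # pred_videos, pred_states
+
+def rollout_once(obs, instruction):
+    # Stage 1 (sim_mode=False): generate an action chunk from obs + text.
+    pred_videos, pred_actions = model_fn(obs, instruction=instruction)
+    # Stage 2 (sim_mode=True): feed the generated actions back into the
+    # world model to predict the future frames and states.
+    pred_future, pred_states = model_fn(obs, actions=pred_actions, sim_mode=True)
+    return pred_actions, pred_future, pred_states
+
+actions, future, states = rollout_once(obs=None, instruction="stack the boxes")
+```
+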
+---
+
+## 3. Core Components in Detail
+
+### 3.1 DDIM Sampler (DDIMSampler)
+
+**Code location**: [src/unifolm_wma/models/samplers/ddim.py](src/unifolm_wma/models/samplers/ddim.py)
+
+**Core method**: `ddim_sampling()` (lines 168-300)
+
+**Actual code structure**:
+```python
+def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...):
+    # Initialization
+    timesteps = self.ddim_timesteps[:ddim_steps]
+    x = x_T if x_T is not None else torch.randn(shape, device=device)
+
+    # Main loop
+    for i, step in enumerate(iterator):
+        # Get the timestep
+        index = total_steps - i - 1
+        ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+        # Model forward pass (the core bottleneck)
+        outs = self.p_sample_ddim(x, cond, ts, index=index, ...)
+        x, pred_x0 = outs
+
+    return x
+```
+
+**Performance data** (from profiling):
+- Total time per denoising step: 10.71s-11.06s (22 calls)
+- Model forward: 325.30s (660 calls, 0.493s/call on average)
+- DDIM update: 0.21s (660 calls, 0.0003s/call on average)
+
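+For reference, the update applied inside `p_sample_ddim` follows the standard DDIM closed form. A minimal sketch of the deterministic case (η = 0), where `alpha_t` and `alpha_prev` are the cumulative ᾱ products as tensors; the repo's v-parameterization-to-ε conversion and classifier-free guidance are omitted:
+
+```python
+import torch
+
+def ddim_step(x_t: torch.Tensor, eps: torch.Tensor,
+              alpha_t: torch.Tensor, alpha_prev: torch.Tensor) -> torch.Tensor:
+    """One deterministic DDIM update (eta = 0)."""
+    # Predict the clean latent x_0 from the current noise estimate.
+    pred_x0 = (x_t - (1.0 - alpha_t).sqrt() * eps) / alpha_t.sqrt()
+    # Re-noise x_0 toward timestep t-1 along the predicted noise direction.
+    dir_xt = (1.0 - alpha_prev).sqrt() * eps
+    return alpha_prev.sqrt() * pred_x0 + dir_xt
+```
+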
+### 3.2 DiffusionWrapper (condition router)
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2413-2524](src/unifolm_wma/models/ddpms.py)
+
+**Role**: routes the inputs and conditions to the inner diffusion model.
+
+**Actual code** (lines 2469-2479):
+```python
+elif self.conditioning_key == 'hybrid':
+    xc = torch.cat([x] + c_concat, dim=1)  # concatenate latent conditions
+    cc = torch.cat(c_crossattn, 1)         # concatenate cross-attention conditions
+    cc_action = c_crossattn_action
+    out = self.diffusion_model(xc, x_action, x_state, t,
+                               context=cc, context_action=cc_action, **kwargs)
+```
+
+**Condition types**:
+1. **c_concat**: channel-concatenated conditions (VAE-encoded images)
+2. **c_crossattn**: cross-attention conditions (text, image, state, and action embeddings)
+3. **c_crossattn_action**: conditions dedicated to the action head
+
+### 3.3 WMAModel (core diffusion model)
+
+**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:326-849](src/unifolm_wma/modules/networks/wma_model.py)
+
+**Config file**: [configs/inference/world_model_interaction.yaml:69-104](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+in_channels: 8                    # input channels (4 latent + 4 VAE condition)
+out_channels: 4                   # output channels
+model_channels: 320               # base channel count
+channel_mult: [1, 2, 4, 4]        # channel multipliers: [320, 640, 1280, 1280]
+num_res_blocks: 2                 # 2 ResBlocks per resolution
+attention_resolutions: [4, 2, 1]  # resolutions at which attention is enabled
+num_head_channels: 64             # 64 channels per attention head
+transformer_depth: 1              # Transformer depth
+context_dim: 1024                 # cross-attention context dimension
+temporal_length: 16               # temporal sequence length
+```
+
+**Architecture** (see Appendix A.1 for details):
+- 4 downsampling stages (2 ResBlocks + attention per stage)
+- 1 middle block (2 ResBlocks + attention)
+- 3 upsampling stages (2 ResBlocks + attention per stage)
+- Total: 16 ResBlocks + 32 Transformers
+
+### 3.4 VAE (variational autoencoder)
+
+**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+
+**Config file**: [configs/inference/world_model_interaction.yaml:159-180](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+AutoencoderKL:
+  embed_dim: 4            # latent dimension
+  z_channels: 4           # latent channels
+  in_channels: 3          # RGB input
+  out_ch: 3               # RGB output
+  ch: 128                 # base channel count
+  ch_mult: [1, 2, 4, 4]   # channel multipliers: [128, 256, 512, 512]
+  num_res_blocks: 2       # 2 ResBlocks per level
+  attn_resolutions: []    # no attention is used in the VAE
+```
+
+**Encode/decode**:
+```python
+# Encode: [B, 3, 320, 512] → [B, 4, 40, 64] (8×8 downsampling)
+z = model.encode_first_stage(img)
+
+# Decode: [B, 4, 40, 64] → [B, 3, 320, 512]
+video = model.decode_first_stage(samples)
+```
+
+**Performance data**:
+- VAE encode: 1.03s (22 calls, 0.047s/call on average)
+- VAE decode: 15.53s (22 calls, 0.706s/call on average)
+- Compression ratio: 8×8 = 64× spatial compression
+
+**Detailed architecture**: see Appendix A.4
+
+### 3.5 Condition Encoders
+
+#### 3.5.1 CLIP image encoder
+
+**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPImageEmbedderV2`
+
+**Config file**: [configs/inference/world_model_interaction.yaml:188-204](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+FrozenOpenCLIPImageEmbedderV2:
+  freeze: true
+  # uses OpenCLIP ViT-H/14
+  # output: [B, 1280]
+
+Resampler (image projector):
+  dim: 1024            # output dimension
+  depth: 4             # Transformer depth
+  heads: 12            # 12 attention heads
+  num_queries: 16      # 16 query tokens
+  embedding_dim: 1280  # CLIP output dimension
+```
+
+**Data flow**: image [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024]
+
+**Performance**: 0.71s (22 calls, 0.032s/call on average)
+
+#### 3.5.2 Text encoder
+
+**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPEmbedder`
+
+**Config file**: [configs/inference/world_model_interaction.yaml:182-186](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+FrozenOpenCLIPEmbedder:
+  freeze: True
+  layer: "penultimate"  # use the second-to-last layer
+  # output: [B, seq_len, 1024]
+```
+
+**Performance**: 0.13s (22 calls, 0.006s/call on average)
+
+#### 3.5.3 State projector
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py) - `MLPProjector`
+
+**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37):
+```python
+class MLPProjector(nn.Module):
+    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
+        if mlp_type == "gelu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.GELU(approximate='tanh'),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+```
+
+**Data flow**: state [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb
+
+**Performance**: 0.006s (22 calls, 0.0003s/call on average)
+
+#### 3.5.4 Action projector
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2020-2024](src/unifolm_wma/models/ddpms.py) - `MLPProjector`
+
+**Data flow**: action [B, T_action, 16] → MLPProjector → [B, T_action, 1024] + agent_action_pos_emb
+
+**Positional embedding definitions**:
+```python
+# ddpms.py:2023-2026
+self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
+self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
+```
+
+**Performance**: 0.003s (22 calls, 0.0001s/call on average)
+
+---
+
+## 4. Performance Bottleneck Analysis
+
+### 4.1 Time breakdown (412.39s total)
+
+According to the profiling data, time is distributed as follows:
+
+| Stage | Total time | Share | Notes |
+|------|--------|------|------|
+| Stage 1: action generation | 171.52s | 41.6% | 30 DDIM sampling steps |
+| Stage 2: world-model interaction | 171.65s | 41.6% | 30 DDIM sampling steps |
+| Model loading | 47.56s | 11.5% | one-time cost |
+| Saving videos | 13.91s | 3.4% | I/O |
+| Saving full videos | 7.22s | 1.8% | I/O |
+| Dataset loading | 0.51s | 0.1% | one-time cost |
+
+### 4.2 Detailed DDIM sampling analysis
+
+**DDIM sampling is the dominant bottleneck: 94.9% of the time spent in the two inference stages (about 83% of the end-to-end total)**
+
+```
+DDIM sampling total: 325.74s
+├── Model forward pass:      325.30s (99.86%) ← core bottleneck
+├── DDIM update rule:        0.21s   (0.06%)
+└── Action/state scheduling: 0.13s   (0.04%)
+```
+
+**Per-step timing**:
+- 30 denoising steps, 10.86s per step on average
+- Each step invokes the model forward twice (once each for Stage 1 and Stage 2)
+- Each forward pass: ~0.493s
+
+### 4.3 Bottleneck summary
+
+**Key findings**:
+1. **The model forward pass accounts for 99.86% of the DDIM time** - this is the core optimization target
+2. VAE decoding takes 4.5% of total time - a secondary target
+3. Everything else (condition encoding, DDIM updates) is negligible
+
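+These attributions can be reproduced with the built-in PyTorch profiler; a minimal sketch, where `run_inference` is a placeholder callable for one full sampling call:
+
+```python
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
+
+def profile_once(run_inference):
+    """Attribute wall time to CPU/CUDA ops for a single inference call."""
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                 record_shapes=True) as prof:
+        with record_function("full_inference"):
+            run_inference()
+    # Sorting by GPU time should confirm that the UNet forward dominates.
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+```
+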
+---
+
+## 5. Kernel Fusion Optimization Recommendations
+
+### 5.1 Strategy overview
+
+Based on the profiling results, optimization should focus on:
+1. **The UNet3D model forward pass** (highest priority)
+2. **The VAE decoder** (secondary priority)
+3. **Batching and parallelization** (auxiliary)
+
+### 5.2 Kernel fusion opportunities in WMAModel
+
+#### 5.2.1 Timestep embedding fusion
+
+**Code location**: [src/unifolm_wma/utils/diffusion.py](src/unifolm_wma/utils/diffusion.py) - `timestep_embedding()`
+
+**Current implementation** (actual code):
+```python
+# 1. Sinusoidal position encoding
+def timestep_embedding(timesteps, dim, max_period=10000):
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half) / half)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    return embedding
+
+# 2. Time embedding network (in WMAModel.__init__)
+self.time_embed = nn.Sequential(
+    nn.Linear(model_channels, time_embed_dim),  # 320 → 1280
+    nn.SiLU(),
+    nn.Linear(time_embed_dim, time_embed_dim),  # 1280 → 1280
+)
+```
+
+**Fusion opportunities**:
+- `Linear + SiLU + Linear` can be fused into a single kernel
+- The sinusoidal encoding can be fused with the first Linear
+
+**Expected gain**: saves 2-3 kernel launches per call
+
+#### 5.2.2 ResBlock kernel fusion
+
+**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py) - `class ResBlock`
+
+**Current implementation** (actual code):
+```python
+# in_layers: GroupNorm + SiLU + Conv
+self.in_layers = nn.Sequential(
+    normalization(channels),  # GroupNorm
+    nn.SiLU(),
+    conv_nd(dims, channels, out_channels, 3, padding=1)
+)
+
+# emb_layers: SiLU + Linear
+self.emb_layers = nn.Sequential(
+    nn.SiLU(),
+    nn.Linear(emb_channels, out_channels)
+)
+
+# out_layers: GroupNorm + SiLU + Dropout + Conv
+self.out_layers = nn.Sequential(
+    normalization(out_channels),  # GroupNorm
+    nn.SiLU(),
+    nn.Dropout(p=dropout),
+    zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
+)
+```
+
+**Fusion opportunities**:
+1. `GroupNorm + SiLU` can be fused (once in in_layers and once in out_layers)
+2. The `SiLU + Linear` in emb_layers can be fused
+3. The residual add can be fused with the next layer's GroupNorm
+
+**Actual load**: 16 ResBlocks × 30 steps × 2 passes = **960 ResBlock invocations**
+
+**Expected gain**: 50-60% less kernel-launch overhead per ResBlock
+
+#### 5.2.3 Attention optimization
+
+**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+
+**Actual configuration**:
+- SpatialTransformer: attention over the spatial dimensions
+- TemporalTransformer: attention over the temporal dimension
+- Total: 32 Transformers × 30 steps × 2 passes = **1920 attention invocations**
+
+**Optimization**:
+Use PyTorch's built-in Flash Attention:
+```python
+from torch.nn.functional import scaled_dot_product_attention
+
+# replaces the standard attention computation
+out = scaled_dot_product_attention(Q, K, V, is_causal=False)
+```
+
+**Expected gain**: 2-3× faster attention layers, 30-40% end-to-end speedup
+
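+Before investing in hand-written kernels, it is worth measuring how much of the `GroupNorm + SiLU` fusion `torch.compile` already delivers. A minimal micro-benchmark sketch; the tensor shape mirrors a Stage-0 feature map, while the batch size and iteration count are illustrative, and a CUDA device is assumed:
+
+```python
+import time
+import torch
+import torch.nn as nn
+
+class NormAct(nn.Module):
+    """The GroupNorm + SiLU prefix of a ResBlock's in_layers."""
+    def __init__(self, channels: int = 320, groups: int = 32):
+        super().__init__()
+        self.norm = nn.GroupNorm(groups, channels)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.norm(x))
+
+def bench(mod, x, iters: int = 50) -> float:
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+    for _ in range(iters):
+        mod(x)
+    torch.cuda.synchronize()
+    return (time.perf_counter() - t0) / iters
+
+x = torch.randn(16, 320, 40, 64, device="cuda")
+eager = NormAct().cuda().eval()
+compiled = torch.compile(eager)  # Inductor can fuse the pointwise SiLU with the norm kernels
+with torch.no_grad():
+    compiled(x)  # warm-up call triggers compilation
+    print(f"eager: {bench(eager, x) * 1e3:.3f} ms  "
+          f"compiled: {bench(compiled, x) * 1e3:.3f} ms")
+```
+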
+### 5.3 VAE decoder optimization
+
+**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+
+**Current performance**: 15.53s (22 calls, 0.706s/call on average)
+
+**Optimizations**:
+1. **Mixed precision**: run the decoder in FP16
+   ```python
+   with torch.cuda.amp.autocast():
+       video = vae.decode(latent)
+   ```
+
+2. **Batching**: make sure VAE decoding is batched rather than frame-by-frame
+
+**Expected gain**: 20-30% speedup
+
+### 5.4 Implementation recommendations
+
+#### 5.4.1 Use torch.compile() (simplest)
+
+**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Where to apply**:
+Add after the model is loaded:
+
+```python
+# after the model is loaded and moved to the GPU
+config = OmegaConf.load(args.config)
+model = instantiate_from_config(config.model)
+model = load_model_checkpoint(model, args.ckpt_path)
+model.eval()
+model = model.cuda()
+
+# add the torch.compile() optimization
+model.model.diffusion_model = torch.compile(
+    model.model.diffusion_model,
+    mode='max-autotune',  # or 'reduce-overhead'
+    fullgraph=True
+)
+```
+
+**Advantages**:
+- no model-code changes required
+- automatic operator fusion
+- supports dynamic shapes
+
+**Expected gain**: 20-40% speedup
+
+#### 5.4.2 Use Flash Attention
+
+**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+
+**Analysis of the current implementation**:
+The code already supports xformers (`xformers.ops.memory_efficient_attention`). When xformers is available, the `CrossAttention` class automatically switches to its `efficient_forward` method:
+
+```python
+# attention.py:90-91
+if XFORMERS_IS_AVAILBLE and temporal_length is None:
+    self.forward = self.efficient_forward
+```
+
+**Further optimization**:
+If xformers is unavailable, PyTorch's built-in Flash Attention can be used instead:
+```python
+from torch.nn.functional import scaled_dot_product_attention
+out = scaled_dot_product_attention(q, k, v, is_causal=False)
+```
+
+**Expected gain**: 2-3× faster attention layers
+
+
+#### 5.4.3 Mixed-precision inference
+
+**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Where to apply**:
+Wrap the inference call in a mixed-precision context:
+
+```python
+# at the image_guided_synthesis_sim_mode call site
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
+        model, sample['instruction'], observation, noise_shape,
+        action_cond_step=args.exe_steps,
+        ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+        unconditional_guidance_scale=args.unconditional_guidance_scale,
+        fs=model_input_fs, timestep_spacing=args.timestep_spacing,
+        guidance_rescale=args.guidance_rescale, sim_mode=False)
+```
+
+**Caveats**:
+- the model switches between FP16 and FP32 automatically
+- certain ops (e.g. LayerNorm) are automatically kept in FP32
+- no manual conversion of the model weights is needed
+
+**Expected gain**: 30-50% speedup plus ~50% less GPU memory
+
+### 5.5 Optimization roadmap
+
+#### Phase 1: Quick wins
+
+**Goal**: a 20-40% speedup with no model-code changes
+
+**Steps**:
+1. Enable `torch.compile()` - add after model loading
+2. Enable `torch.backends.cudnn.benchmark = True` - set before inference starts
+3. Use mixed-precision inference (FP16) - add at the inference call site
+
+**Code**:
+```python
+# at the start of the inference function
+torch.backends.cudnn.benchmark = True
+
+# after model loading, add torch.compile()
+model = model.cuda()
+model.model.diffusion_model = torch.compile(
+    model.model.diffusion_model,
+    mode='max-autotune'
+)
+
+# use mixed precision in the inference loop
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
+```
+
+#### Phase 2: Intermediate optimizations
+
+**Goal**: a 50-70% speedup
+
+**Steps**:
+1. Make sure xformers is installed and enabled - check the `XFORMERS_IS_AVAILBLE` flag
+2. Optimize VAE decoder batching - see the `decode()` method in [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+3. Analyze and optimize memory-access patterns - use `torch.cuda.memory_stats()`
+
+**Key touch points**:
+- confirm xformers is properly installed: `pip install xformers`
+- in the `CrossAttention` class, `efficient_forward` is used automatically once xformers is available
+- make sure VAE decoding is batched rather than frame-by-frame (a sketch of chunked decoding follows below)
+
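+What "batched decoding" means in practice: fold the temporal axis into the batch axis and decode several frames per `decode()` call instead of one. A sketch under the assumption that `vae.decode` accepts `[N, 4, h, w]` latents; the chunk size is illustrative, and note that the configs set `perframe_ae: True` precisely to bound decoder memory, so the chunk size trades speed against memory:
+
+```python
+import torch
+
+@torch.no_grad()
+def decode_latents_chunked(vae, z: torch.Tensor, chunk: int = 8) -> torch.Tensor:
+    """Decode [B, 4, T, h, w] latents to frames, `chunk` frames at a time."""
+    B, C, T, h, w = z.shape
+    flat = z.permute(0, 2, 1, 3, 4).reshape(B * T, C, h, w)   # fold T into batch
+    frames = [vae.decode(flat[i:i + chunk]) for i in range(0, B * T, chunk)]
+    out = torch.cat(frames, dim=0)                            # [B*T, 3, H, W]
+    out = out.reshape(B, T, *out.shape[1:])                   # [B, T, 3, H, W]
+    return out.permute(0, 2, 1, 3, 4)                         # [B, 3, T, H, W]
+```
+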
+
+#### 5.4.3 Mixed-Precision Inference
+
+**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Where to apply**:
+Wrap the inference call in an autocast context:
+
+```python
+# Add at the image_guided_synthesis_sim_mode call site
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
+        model, sample['instruction'], observation, noise_shape,
+        action_cond_step=args.exe_steps,
+        ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+        unconditional_guidance_scale=args.unconditional_guidance_scale,
+        fs=model_input_fs, timestep_spacing=args.timestep_spacing,
+        guidance_rescale=args.guidance_rescale, sim_mode=False)
+```
+
+**Notes**:
+- The model switches between FP16 and FP32 automatically
+- Some operations (e.g. LayerNorm) automatically stay in FP32
+- No manual conversion of model weights is needed
+
+**Expected gain**: 30-50% speedup + 50% less GPU memory
+
+### 5.5 Optimization Roadmap
+
+#### Stage 1: Quick wins
+
+**Goal**: 20-40% speedup without touching model code
+
+**Steps**:
+1. Enable `torch.compile()` - add after model loading
+2. Enable `torch.backends.cudnn.benchmark = True` - set before inference starts
+3. Use mixed-precision (FP16) inference - add at the inference call site
+
+**Code**:
+```python
+# Add at the start of the inference function
+torch.backends.cudnn.benchmark = True
+
+# Add torch.compile() after model loading
+model = model.cuda()
+model.model.diffusion_model = torch.compile(
+    model.model.diffusion_model,
+    mode='max-autotune'
+)
+
+# Use mixed precision in the inference loop
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
+```
+
+#### Stage 2: Intermediate
+
+**Goal**: 50-70% speedup
+
+**Steps**:
+1. Make sure xformers is installed and enabled - check the `XFORMERS_IS_AVAILBLE` flag
+2. Batch the VAE decoder - check the `decode()` method in [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+3. Analyze and optimize memory-access patterns - profile with `torch.cuda.memory_stats()`
+
+**Key changes**:
+- Confirm xformers is installed correctly: `pip install xformers`
+- In `CrossAttention`, `efficient_forward` is used automatically when xformers is available
+- Make sure the VAE decodes in batches rather than frame by frame
+
+#### Stage 3: Deep optimization
+
+**Goal**: 2-3x speedup
+
+**Steps**:
+1. Fuse the key operations with custom CUDA kernels
+   - Fuse GroupNorm + SiLU + Conv (in the ResBlock at [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py))
+2. Optimize the convolutions
+   - Profile the Conv2D operations (the model actually uses Conv2D, not Conv3D)
+3. Optimize the data-loading and preprocessing pipeline
+   - Check the data-preparation part of [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Skills required**:
+- CUDA programming
+- PyTorch C++ extensions
+- Profiling tools (Nsight Systems, nvprof)
+
+
+### 5.6 Expected Overall Gains
+
+Based on the strategies above and the actual code analysis, the expected improvements are:
+
+| Stage | Expected speedup | Difficulty | Main files touched |
+|---------|-----------|---------|-------------|
+| Stage 1: Quick wins | 1.2-1.4x | Low | [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py) |
+| Stage 2: Intermediate | 1.5-1.7x | Medium | [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py), [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py) |
+| Stage 3: Deep optimization | 2.0-3.0x | High | [src/unifolm_wma/modules/networks/wma_model.py](src/unifolm_wma/modules/networks/wma_model.py) + custom CUDA kernels |
+
+**Overall target**: a 2-3x end-to-end speedup through systematic optimization
+
+---
+
+## 6. Index of Key Code Locations
+
+For implementing kernel fusion, the key code locations are:
+
+### 6.1 Core model files
+
+| Component | File path | Key class/function |
+|------|---------|-----------|
+| DDPM main model | `src/unifolm_wma/models/ddpms.py` | `class DDPM` |
+| Conditioning wrapper | `src/unifolm_wma/models/ddpms.py:2413` | `class DiffusionWrapper` |
+| DDIM sampler | `src/unifolm_wma/models/samplers/ddim.py` | `class DDIMSampler` |
+| VAE encode/decode | `src/unifolm_wma/models/autoencoder.py` | `encode_first_stage`, `decode_first_stage` |
+
+### 6.2 Inference script
+
+| File | Description |
+|------|------|
+| `scripts/evaluation/world_model_interaction.py` | Inference script |
+
+### 6.3 Configuration files
+
+| File | Description |
+|------|------|
+| `configs/inference/world_model_interaction.yaml` | Inference configuration |
+| `unitree_g1_pack_camera/case1/run_world_model_interaction.sh` | Launch script |
+
+
+---
+
+## 7. Recommended Next Steps
+
+### 7.1 Optimizations that can be applied immediately
+
+**Smallest change, largest gain**:
+
+1. **Enable torch.compile()**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: add after model loading
+   ```python
+   # Add after the model is loaded and moved to the GPU
+   model.model.diffusion_model = torch.compile(
+       model.model.diffusion_model,
+       mode='max-autotune'
+   )
+   ```
+
+2. **Enable cuDNN benchmark**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: add at the start of the inference function
+   ```python
+   torch.backends.cudnn.benchmark = True
+   ```
+
+3. **Mixed-precision inference**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: add at the `image_guided_synthesis_sim_mode` call site
+   ```python
+   with torch.cuda.amp.autocast():
+       pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
+   ```
+
+**Expected gain**: 20-40% speedup, no risk
+
+
+### 7.2 Areas that need deeper investigation
+
+For more precise optimization, the following deserve further analysis (a profiling sketch follows after this list):
+
+1. **The concrete attention implementation**
+   - **Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+   - **Targets**:
+     - The `CrossAttention` class (lines 48-398) - the core attention implementation
+     - The `BasicTransformerBlock` class (lines 400-469) - the Transformer block
+     - Confirm whether xformers is enabled (the `XFORMERS_IS_AVAILBLE` flag)
+
+2. **The detailed ResBlock structure**
+   - **Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py)
+   - **Targets**:
+     - Confirm the GroupNorm + SiLU + Conv call order
+     - Identify operation sequences that can be fused
+     - Assess the feasibility of custom CUDA kernels
+
+3. **Memory-bottleneck analysis**
+   - **Tools**: `torch.cuda.memory_stats()` and `torch.profiler`
+   - **Where**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Targets**:
+     - Find memory-copy hot spots
+     - Optimize the lifetime of intermediate tensors
+     - Reduce unnecessary allocations
+
+4. **Compute-bottleneck localization**
+   - **Tools**: Nsight Systems or the PyTorch Profiler
+   - **Targets**:
+     - Measure kernel-launch overhead
+     - Analyze GPU utilization
+     - Find the compute-heavy operations
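+
+For items 3 and 4, a minimal sketch of wrapping one sampling call with `torch.profiler`; `run_one_inference` is a placeholder for the actual call into `image_guided_synthesis_sim_mode`:
+
+```python
+import torch
+from torch.profiler import profile, ProfilerActivity
+
+def profile_inference(run_one_inference):
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+    with profile(activities=activities,
+                 profile_memory=True, record_shapes=True) as prof:
+        run_one_inference()
+    # Top operators by GPU time, plus a Chrome trace for timeline analysis
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+    prof.export_chrome_trace("wma_inference_trace.json")
+```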
+
+---
+
+## 8. References
+
+### 8.1 Optimization guides
+
+- [PyTorch 2.0 torch.compile()](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
+- [Flash Attention](https://github.com/Dao-AILab/flash-attention)
+- [CUDA Kernel Fusion](https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/)
+
+### 8.2 Related papers
+
+- DDIM: Denoising Diffusion Implicit Models
+- Flash Attention: Fast and Memory-Efficient Exact Attention
+- Efficient Diffusion Models for Vision
+
+---
+
+## 9. Summary
+
+### 9.1 Key findings
+
+1. **The model forward pass takes 99.86% of the DDIM sampling time** - this is the absolute core of the optimization
+2. **The 30 DDIM steps take 83% of total time** - reducing the step count or accelerating a single step is key
+3. **VAE decoding takes 4.5%** - a secondary target
+
+### 9.2 Optimization priorities
+
+**High priority** (apply now):
+- ✅ torch.compile()
+- ✅ cuDNN benchmark
+- ✅ Mixed-precision inference
+
+**Medium priority** (within a week):
+- Flash Attention integration
+- Batched VAE decoding
+
+**Low priority** (long term):
+- Custom CUDA kernels
+- Model-architecture changes
+
+### 9.3 Expected outcome
+
+With systematic optimization, inference time is expected to drop from **412s to 140-200s**, a **2-3x speedup**.
+
+---
+
+**Document version**: v1.1
+**Created**: 2026-01-17
+**Last updated**: 2026-01-17
+**Changes**: file paths, line numbers, component locations, and implementation details verified and corrected against the actual code
+
+
+---
+
+## Appendix A: The Actual Model Architecture in Detail
+
+The following implementation details are based on code analysis.
+
+### A.1 Actual WMAModel configuration
+
+**Configuration file**: `configs/inference/world_model_interaction.yaml:69-104`
+
+```yaml
+WMAModel parameters:
+  in_channels: 8              # input channels (4 latent + 4 concat condition)
+  out_channels: 4             # output channels (latent space)
+  model_channels: 320         # base channel count
+  channel_mult: [1, 2, 4, 4]  # channel multipliers: [320, 640, 1280, 1280]
+  num_res_blocks: 2           # 2 ResBlocks per resolution
+  attention_resolutions: [4, 2, 1]  # resolutions with attention enabled
+  num_head_channels: 64       # 64 channels per attention head
+  transformer_depth: 1        # Transformer depth
+  context_dim: 1024           # cross-attention context dimension
+  temporal_length: 16         # temporal sequence length
+  dropout: 0.1
+```
+
+**Architecture**:
+```
+Input: [B, 8, 16, 40, 64]  (8 channels = 4 latent + 4 VAE condition)
+    ↓
+Downsampling path (4 stages):
+  Stage 0: [B, 320, 16, 40, 64]  - 2 ResBlocks + SpatialTransformer + TemporalTransformer
+  Stage 1: [B, 640, 16, 20, 32]  - Downsample + 2 ResBlocks + Attention
+  Stage 2: [B, 1280, 16, 10, 16] - Downsample + 2 ResBlocks + Attention
+  Stage 3: [B, 1280, 16, 5, 8]   - Downsample + 2 ResBlocks + Attention
+    ↓
+Middle block: [B, 1280, 16, 5, 8]
+  - ResBlock + SpatialTransformer + TemporalTransformer + ResBlock
+    ↓
+Upsampling path (3 stages):
+  Stage 2: [B, 1280, 16, 10, 16] - Upsample + 2 ResBlocks + Attention
+  Stage 1: [B, 640, 16, 20, 32]  - Upsample + 2 ResBlocks + Attention
+  Stage 0: [B, 320, 16, 40, 64]  - Upsample + 2 ResBlocks + Attention
+    ↓
+Output: [B, 4, 16, 40, 64]  (predicted noise or velocity)
+```
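+
+As a quick sanity check of the channel/resolution arithmetic implied by this configuration, a small self-contained sketch (pure arithmetic, no model code assumed):
+
+```python
+def unet_stage_shapes(model_channels=320, channel_mult=(1, 2, 4, 4),
+                      h=40, w=64, t=16):
+    """Reproduce the per-stage [C, T, H, W] sizes from the config above."""
+    shapes = []
+    for i, mult in enumerate(channel_mult):
+        scale = 2 ** i  # each stage after the first halves H and W
+        shapes.append((model_channels * mult, t, h // scale, w // scale))
+    return shapes
+
+for stage, (c, t, hh, ww) in enumerate(unet_stage_shapes()):
+    print(f"Stage {stage}: [B, {c}, {t}, {hh}, {ww}]")
+# Stage 0: [B, 320, 16, 40, 64] ... Stage 3: [B, 1280, 16, 5, 8]
+```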
+
+
+### A.2 Actual ResBlock implementation
+
+**Location**: `src/unifolm_wma/modules/networks/wma_model.py:130-263`
+
+**Actual code structure**:
+```python
+class ResBlock:
+    def __init__(self, channels, emb_channels, dropout, ...):
+        # input layers: GroupNorm + SiLU + Conv
+        self.in_layers = nn.Sequential(
+            normalization(channels),   # GroupNorm
+            nn.SiLU(),                 # activation
+            conv_nd(dims, channels, out_channels, 3, padding=1)
+        )
+
+        # timestep-embedding layers: SiLU + Linear
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_channels, out_channels)
+        )
+
+        # output layers: GroupNorm + SiLU + Dropout + Conv
+        self.out_layers = nn.Sequential(
+            normalization(out_channels),  # GroupNorm
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
+        )
+
+        # residual connection
+        self.skip_connection = ...
+
+        # optional temporal convolution
+        if use_temporal_conv:
+            self.temporal_conv = TemporalConvBlock(...)
+
+    def forward(self, x, emb):
+        h = self.in_layers(x)            # GroupNorm + SiLU + Conv
+        emb_out = self.emb_layers(emb)   # timestep embedding
+        h = h + emb_out                  # inject timestep information
+        h = self.out_layers(h)           # GroupNorm + SiLU + Dropout + Conv
+        h = self.skip_connection(x) + h  # residual connection
+
+        if use_temporal_conv:
+            h = self.temporal_conv(h)    # convolution over time
+        return h
+```
+
+**Kernel-fusion opportunities**:
+1. `GroupNorm + SiLU` can be fused (once each in in_layers and out_layers)
+2. The `SiLU + Linear` in emb_layers can be fused
+3. `residual add + the next layer's GroupNorm` can be fused
+
+
+### A.3 Actual attention implementation
+
+**SpatialTransformer** (spatial attention):
+- Location: `src/unifolm_wma/modules/attention.py:472-558`
+- Runs self- and cross-attention over the spatial dimensions (H×W) of the feature map
+- Uses `transformer_depth=1`, i.e. one Transformer layer per position
+- When xformers is available, attention goes through the efficient `efficient_forward` path
+
+**TemporalTransformer** (temporal attention):
+- Location: `src/unifolm_wma/modules/attention.py:561-680`
+- Runs self-attention over the temporal dimension (T=16 frames)
+- Configuration: `temporal_selfatt_only=True` (temporal self-attention only, no cross-attention)
+- Relative position encoding: `use_relative_position=False` (not actually enabled)
+
+**CrossAttention** (the core attention layer):
+- Location: `src/unifolm_wma/modules/attention.py:48-398`
+- Supports several kinds of cross-attention: image, text, state, and action
+- Automatically uses `xformers.ops.memory_efficient_attention` when xformers is available
+
+**Attention-head configuration**:
+- `num_head_channels=64`: 64 channels per head
+- At 320 channels: 320/64 = 5 attention heads
+- At 640 channels: 640/64 = 10 attention heads
+- At 1280 channels: 1280/64 = 20 attention heads
+
+
+### A.4 Actual VAE configuration
+
+**Configuration file**: `configs/inference/world_model_interaction.yaml:159-180`
+
+```yaml
+AutoencoderKL:
+  embed_dim: 4          # latent dimension
+  z_channels: 4         # latent channels
+  resolution: 256       # base resolution
+  in_channels: 3        # RGB input
+  out_ch: 3             # RGB output
+  ch: 128               # base channel count
+  ch_mult: [1, 2, 4, 4] # channel multipliers
+  num_res_blocks: 2     # 2 ResBlocks per level
+  attn_resolutions: []  # no attention in the VAE
+  dropout: 0.0
+```
+
+**Encoder architecture**:
+```
+Input: [B, 3, 320, 512]
+    ↓ Conv 3→128
+    ↓ ResBlock×2   [128, 320, 512]
+    ↓ Downsample   [128, 160, 256]
+    ↓ ResBlock×2   [256, 160, 256]
+    ↓ Downsample   [256, 80, 128]
+    ↓ ResBlock×2   [512, 80, 128]
+    ↓ Downsample   [512, 40, 64]
+    ↓ ResBlock×2   [512, 40, 64]
+    ↓ ResBlock + Conv
+Output: [B, 4, 40, 64]  (8×8 downsampling)
+```
+
+**Decoder architecture** (mirror of the encoder):
+```
+Input: [B, 4, 40, 64]
+    ↓ Conv + ResBlock
+    ↓ ResBlock×2   [512, 40, 64]
+    ↓ Upsample     [512, 80, 128]
+    ↓ ResBlock×2   [512, 80, 128]
+    ↓ Upsample     [256, 160, 256]
+    ↓ ResBlock×2   [256, 160, 256]
+    ↓ Upsample     [128, 320, 512]
+    ↓ ResBlock×2   [128, 320, 512]
+    ↓ Conv 128→3
+Output: [B, 3, 320, 512]
+```
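+
+Related to the batched-decoding recommendation in Section 5.3, a minimal sketch of pushing all latent frames through the decoder in chunks rather than one frame at a time; `vae` and the chunk size are assumptions, not the repository's `decode_first_stage`:
+
+```python
+import torch
+from einops import rearrange
+
+@torch.no_grad()
+def decode_latents_batched(vae, z, chunk: int = 8):
+    """z: [B, 4, T, 40, 64] -> video [B, 3, T, 320, 512].
+
+    Decoding (B*T) frames in chunks amortizes kernel-launch and Python
+    overhead versus a per-frame loop, at the cost of higher peak memory.
+    """
+    b = z.shape[0]
+    frames = rearrange(z, 'b c t h w -> (b t) c h w')
+    out = torch.cat([vae.decode(frames[i:i + chunk])
+                     for i in range(0, frames.shape[0], chunk)], dim=0)
+    return rearrange(out, '(b t) c h w -> b c t h w', b=b)
+```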
+
+
+### A.5 Actual conditioning-encoder configuration
+
+#### CLIP image encoder
+
+**Configuration**: `configs/inference/world_model_interaction.yaml:188-191`
+```yaml
+FrozenOpenCLIPImageEmbedderV2:
+  freeze: true
+  # uses OpenCLIP's ViT-H/14 model
+  # output dimension: 1280
+```
+
+**Image projector (Resampler)**:
+```yaml
+Resampler:
+  dim: 1024            # output dimension
+  depth: 4             # Transformer depth
+  dim_head: 64         # attention-head dimension
+  heads: 12            # 12 attention heads
+  num_queries: 16      # 16 query tokens
+  embedding_dim: 1280  # CLIP output dimension
+  output_dim: 1024     # final output dimension
+  video_length: 16     # video length
+```
+
+**Data flow**:
+```
+Image [B, 3, H, W]
+    ↓ CLIP Encoder
+    ↓ [B, 1280]
+    ↓ Resampler (Perceiver-style)
+    ↓ [B, 16, 1024]  (16 tokens, 1024 dims each)
+```
+
+
+#### Text encoder
+
+**Configuration**: `configs/inference/world_model_interaction.yaml:182-186`
+```yaml
+FrozenOpenCLIPEmbedder:
+  freeze: True
+  layer: "penultimate"  # use the second-to-last layer
+  # output dimension: 1024
+```
+
+**Data flow**:
+```
+Text instruction "pick up the box"
+    ↓ OpenCLIP Text Encoder
+    ↓ [B, seq_len, 1024]
+```
+
+#### Action/state projectors
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py)
+
+**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37):
+```python
+class MLPProjector(nn.Module):
+    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
+        if mlp_type == "gelu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.GELU(approximate='tanh'),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+```
+
+**Initialization code** (ddpms.py:2014-2026):
+```python
+# state and action projectors
+self.state_projector = MLPProjector(agent_state_dim, 1024)    # 16 → 1024
+self.action_projector = MLPProjector(agent_action_dim, 1024)  # 16 → 1024
+
+# positional embeddings
+self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
+self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
+```
+
+**Data flow**:
+```
+State [B, T_obs, 16]
+    ↓ MLPProjector (Linear + GELU + Linear)
+    ↓ [B, T_obs, 1024]
+    ↓ + agent_state_pos_emb
+    ↓ [B, T_obs, 1024]
+
+Action [B, T_action, 16]
+    ↓ MLPProjector (Linear + GELU + Linear)
+    ↓ [B, T_action, 1024]
+    ↓ + agent_action_pos_emb
+    ↓ [B, T_action, 1024]
+```
+
+
+### A.6 Actual timestep-embedding implementation
+
+**Location**: `src/unifolm_wma/utils/diffusion.py:timestep_embedding`
+
+**Actual code**:
+```python
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal position encodings.
+    """
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half) / half
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    return embedding
+```
+
+**Usage inside WMAModel**:
+```python
+# timestep t ∈ [0, 999]
+t_emb = timestep_embedding(t, model_channels)  # [B, 320]
+t_emb = self.time_embed(t_emb)                 # Linear(320 → 1280)
+# output: [B, 1280]
+```
+
+**Time-embedding network**:
+```
+t ∈ [0, 999]
+    ↓ timestep_embedding (sinusoidal encoding)
+    ↓ [B, 320]
+    ↓ Linear(320 → 1280)
+    ↓ SiLU
+    ↓ Linear(1280 → 1280)
+    ↓ [B, 1280]
+```
+
+
+### A.7 Actual action-head (ConditionalUnet1D) configuration
+
+**Configuration**: `configs/inference/world_model_interaction.yaml:106-127`
+
+```yaml
+ConditionalUnet1D:
+  input_dim: 16                      # action dimension
+  n_obs_steps: 2                     # number of observation steps
+  diffusion_step_embed_dim: 128      # diffusion-step embedding dimension
+  down_dims: [256, 512, 1024, 2048]  # downsampling channels
+  kernel_size: 5                     # convolution kernel size
+  n_groups: 8                        # GroupNorm groups
+  horizon: 16                        # prediction horizon
+  use_linear_attn: true              # use linear attention
+  imagen_cond_gradient: true         # use image-conditioned gradients
+```
+
+**Architecture**:
+```
+Input: noisy actions [B, 16, 16]  (16-dim actions × 16 steps)
+Conditioning:
+  - image features [B, C, H, W] from WMAModel's intermediate layers
+  - observation encoding [B, n_obs, obs_dim]
+
+    ↓ Conv1D + ResBlock
+    ↓ [B, 256, 16]
+    ↓ Downsample + ResBlock
+    ↓ [B, 512, 8]
+    ↓ Downsample + ResBlock
+    ↓ [B, 1024, 4]
+    ↓ Downsample + ResBlock
+    ↓ [B, 2048, 2]
+    ↓ Middle block (with attention)
+    ↓ Upsample + ResBlock
+    ↓ [B, 1024, 4]
+    ↓ Upsample + ResBlock
+    ↓ [B, 512, 8]
+    ↓ Upsample + ResBlock
+    ↓ [B, 256, 16]
+    ↓ Conv1D
+Output: predicted noise [B, 16, 16]
+```
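+
+Following up on the time-embedding fusion opportunity above, a minimal sketch of precomputing the frequency table once as a buffer, so each sampling step only pays for the trig and the two Linears; this is an illustration, not the repository's implementation:
+
+```python
+import math
+import torch
+import torch.nn as nn
+
+class CachedTimestepEmbedding(nn.Module):
+    def __init__(self, dim: int = 320, max_period: int = 10000):
+        super().__init__()
+        half = dim // 2
+        freqs = torch.exp(-math.log(max_period) *
+                          torch.arange(half, dtype=torch.float32) / half)
+        self.register_buffer("freqs", freqs)  # computed once, reused every step
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        args = t[:, None].float() * self.freqs[None]
+        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+emb = CachedTimestepEmbedding()(torch.tensor([999, 500]))  # [2, 320]
+```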
+
+
+### A.8 The complete forward pass
+
+Based on the actual code, the complete forward pass is:
+
+```python
+# 1. Condition encoding (done once; cacheable)
+cond_img_emb = clip_encoder(img) → resampler → [B, 16, 1024]
+cond_text_emb = text_encoder(text) → [B, seq_len, 1024]
+cond_state_emb = state_projector(state) + pos_emb → [B, T_obs, 1024]
+cond_action_emb = action_projector(action) + pos_emb → [B, T_action, 1024]
+cond_latent = vae.encode(img) → [B, 4, T, 40, 64]
+
+# 2. Concatenate conditions
+c_concat = [cond_latent]  # channel concatenation
+c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb]
+c_crossattn = torch.cat(c_crossattn, dim=1)  # [B, total_tokens, 1024]
+
+# 3. DDIM sampling loop (30 steps)
+x = torch.randn([B, 4, 16, 40, 64])  # initial noise
+for step in range(30):
+    # 3.1 Timestep embedding
+    t_emb = timestep_embedding(t, 320) → Linear → [B, 1280]
+
+    # 3.2 Concatenate input
+    x_in = torch.cat([x, cond_latent], dim=1)  # [B, 8, 16, 40, 64]
+
+    # 3.3 UNet forward pass (the core bottleneck)
+    noise_pred = wma_model(x_in, t_emb, c_crossattn)
+    # contains: 4 down stages + middle block + 3 up stages
+    # each stage: 2 ResBlocks + SpatialTransformer + TemporalTransformer
+
+    # 3.4 DDIM update
+    x = ddim_update(x, noise_pred, t, t_prev)
+
+# 4. VAE decode
+video = vae.decode(x) → [B, 3, 16, 320, 512]
+```
+
+
+### A.9 Updated optimization suggestions based on the actual architecture
+
+#### Optimization 1: ResBlock fusion (high priority)
+
+**Actual bottleneck**:
+- Each DDIM step calls the UNet once
+- The UNet contains 4 down stages + 1 middle block + 3 up stages = 8 stages
+- Each stage has 2 ResBlocks
+- Total: 16 ResBlocks × 30 steps × 2 passes (stages 1+2) = **960 ResBlock calls**
+
+**Fusion opportunity**:
+```python
+# current: 7 kernel launches
+h = group_norm(x)   # kernel 1
+h = silu(h)         # kernel 2
+h = conv2d(h)       # kernel 3
+h = group_norm(h)   # kernel 4
+h = silu(h)         # kernel 5
+h = conv2d(h)       # kernel 6
+out = x + h         # kernel 7
+
+# after fusion: 2-3 kernel launches
+h = fused_norm_silu_conv(x)     # kernel 1 (fused)
+h = fused_norm_silu_conv(h)     # kernel 2 (fused)
+out = fused_residual_add(x, h)  # kernel 3 (fused)
+```
+
+**Expected gain**: 50-60% less kernel-launch overhead per ResBlock
+
+
+#### Optimization 2: Attention (high priority)
+
+**Actual configuration**:
+- SpatialTransformer: after every ResBlock in every stage
+- TemporalTransformer: after every ResBlock in every stage
+- Total: 16 Spatial + 16 Temporal = **32 Transformers × 30 steps × 2 passes = 1920 attention calls**
+
+**The current implementation already supports xformers**:
+The code probes for xformers at `attention.py:8-13`:
+```python
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+```
+
+When xformers is available, `CrossAttention` automatically uses `efficient_forward` (attention.py:90-91).
+
+**Further optimization** (if xformers is unavailable):
+```python
+# Use PyTorch's built-in Flash Attention
+from torch.nn.functional import scaled_dot_product_attention
+out = scaled_dot_product_attention(Q, K, V, is_causal=False)
+```
+
+**Expected gain**: with xformers enabled, attention is already optimized; otherwise Flash Attention gives a 2-3x speedup
+
diff --git a/prepare_data/prepare_training_data.py b/prepare_data/prepare_training_data.py
new file mode 100644
index 0000000..3d899e3
--- /dev/null
+++ b/prepare_data/prepare_training_data.py
@@ -0,0 +1,199 @@
+import json
+import os
+import shutil
+import h5py
+import argparse
+import pandas as pd
+import torch
+import subprocess
+
+from pathlib import Path
+from safetensors.torch import save_file
+from tqdm import tqdm
+
+
+def flatten_dict(d, parent_key="", sep="/"):
+    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
+ + For example: + ``` + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` + >>> print(flatten_dict(dct)) + {"a/b": 1, "a/c/d": 2, "e": 3} + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def is_av1(file_path): + try: + result = subprocess.run([ + "ffprobe", "-v", "error", "-select_streams", "v:0", + "-show_entries", "stream=codec_name", "-of", "csv=p=0", + str(file_path) + ], + capture_output=True, + text=True, + check=True) + return result.stdout.strip() == "av1" + except subprocess.CalledProcessError: + return False + + +def convert_to_h264(input_path, output_path): + subprocess.run([ + "ffmpeg", "-i", + str(input_path), "-c:v", "libx264", "-preset", "slow", "-crf", "23", + "-c:a", "copy", + str(output_path) + ], + check=True) + + +def main(args): + source_dir = Path(args.source_dir) + source_data_dir = source_dir / args.dataset_name / "data" / "chunk-000" + source_meta_dir = source_dir / args.dataset_name / "meta" + source_videos_dir = source_dir / args.dataset_name / "videos" / "chunk-000" + + target_dir = Path(args.target_dir) + target_videos_dir = target_dir / "videos" / args.dataset_name + target_transitions_dir = target_dir / "transitions" / args.dataset_name + target_meta_dir = target_dir / "transitions" / args.dataset_name / "meta_data" + + target_dir.mkdir(parents=True, exist_ok=True) + target_videos_dir.mkdir(parents=True, exist_ok=True) + target_transitions_dir.mkdir(parents=True, exist_ok=True) + target_meta_dir.mkdir(parents=True, exist_ok=True) + + csv_file = target_dir / f"{args.dataset_name}.csv" + COLUMNS = [ + 'videoid', 'contentUrl', 'duration', 'data_dir', 'instruction', + 'dynamic_confidence', 'dynamic_wording', 'dynamic_source_category', + 'embodiment' + ] + df = pd.DataFrame(columns=COLUMNS) + + # Load info.json from source dir + info_json_path = source_meta_dir / "info.json" + with open(str(info_json_path), "r") as f: + info = json.load(f) + total_episodes = info['total_episodes'] + + # Load task.jsonl to get lanugage ins + tasks_jsonl_path = source_meta_dir / "tasks.jsonl" + with open(str(tasks_jsonl_path), "r") as f: + tasks = [json.loads(line) for line in f] + instruction = tasks[0]['task'] + + source_video_views = [d for d in source_videos_dir.iterdir()] + for v_idx, source_view_dir in enumerate(source_video_views): + + view_name = source_view_dir.name + target_videos_view_dir = target_videos_dir / view_name + target_videos_view_dir.mkdir(parents=True, exist_ok=True) + + if v_idx == 0: + all_actions = [] + all_states = [] + + for idx in tqdm(range(total_episodes)): + # Copy source video to target vidoe dir + source_video = source_view_dir / f"episode_{idx:06d}.mp4" + if is_av1(source_video): + output_video = str(target_videos_view_dir / f"{idx}.mp4") + print(f"Converting episode_{idx:06d}.mp4 to H.264...") + convert_to_h264(source_video, output_video) + else: + print(f"Skipping episode_{idx:06d}.mp4: not AV1 encoded.") + + # Load parquet file + episode_parquet_file = source_data_dir / f"episode_{idx:06d}.parquet" + episode_data = pd.read_parquet(episode_parquet_file) + actions = torch.tensor(episode_data['action'].tolist()) + states = torch.tensor(episode_data['observation.state'].tolist()) + + # Save action and state into a h5 file + if v_idx == 0: + target_h5_file = target_transitions_dir / f"{idx}.h5" + with h5py.File(str(target_h5_file), 'w') as h5f: + 
h5f.create_dataset('observation.state', data=states) + h5f.create_dataset('action', data=actions) + h5f.attrs['action_type'] = 'joint position' + h5f.attrs['state_type'] = 'joint position' + h5f.attrs['robot_type'] = args.robot_name + + # Updata df + df = pd.concat([ + df, + pd.DataFrame([{ + 'videoid': idx, + 'contentUrl': 'x', + 'duration': 'x', + 'data_dir': args.dataset_name + f"/{view_name}", + 'instruction': instruction, + 'dynamic_confidence': 'x', + 'dynamic_wording': 'x', + 'dynamic_source_category': 'x', + 'embodiment': args.robot_name + }]) + ], + ignore_index=True) + + # Collect action and state + if v_idx == 0: + all_actions.append(actions) + all_states.append(states) + + # Create satas.safetensors + actions = torch.cat(all_actions, dim=0) + states = torch.cat(all_states, dim=0) + + stats = {'action': {}, 'observation.state': {}} + stats['action']['max'] = actions.max(dim=0).values + stats['action']['min'] = actions.min(dim=0).values + stats['action']['mean'] = actions.mean(dim=0) + stats['action']['std'] = actions.std(dim=0) + + stats['observation.state']['max'] = states.max(dim=0).values + stats['observation.state']['min'] = states.min(dim=0).values + stats['observation.state']['mean'] = states.mean(dim=0) + stats['observation.state']['std'] = states.std(dim=0) + + flattened_stats = flatten_dict(stats) + target_stats_file = target_meta_dir / "stats.safetensors" + save_file(flattened_stats, target_stats_file) + + df.to_csv(csv_file, index=False) + print(f">>> Finished create {args.dataset_name} dataset ...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--source_dir', + action='store', + type=str, + help='The dataset dir under lerobot 2.0 data format.', + required=True) + parser.add_argument('--target_dir', + action='store', + type=str, + default='./data', + help='The target dir to save new formatted dataset.') + parser.add_argument('--dataset_name', + action='store', + type=str, + help='dataset name', + required=True) + parser.add_argument('--robot_name', + action='store', + type=str, + help='robot name', + required=True) + main(parser.parse_args()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000..e08d9e6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "unifolm_wma" +version = "0.0.1" +description = "UnifoLM-WMA-0" +license = { text = "BSD-3-Clause" } +authors = [ + {name="Unitree Embodied AI R&D Team", email="rd_xyc@unitree.com" } +] +requires-python = "==3.10.18" +dependencies = [ + "decord==0.6.0", + "einops==0.8.0", + "imageio==2.35.1", + "numpy==1.24.2", + "omegaconf==2.3.0", + "opencv-python==4.10.0.84", + "pandas==2.0.0", + "pillow==9.5.0", + "pytorch-lightning==1.9.3", + "pyyaml==6.0", + "setuptools==65.6.3", + "torch==2.3.1", + "torchvision==0.18.1", + "tqdm==4.66.5", + "transformers==4.40.1", + "moviepy==1.0.3", + "av==12.3.0", + "xformers==0.0.27", + "gradio==4.39.0", + "timm==0.9.10", + "scikit-learn==1.5.1", + "open-clip-torch==2.22.0", + "kornia==0.7.3", + "diffusers==0.30.2", + "termcolor==2.4.0", + "draccus==0.11.5", + "accelerate==1.7.0", + "tensorflow-metadata==1.16.1", + "protobuf==3.20.3", + "datasets==3.6.0", + "tensorflow-graphics==2021.12.3", + "fairscale==0.4.13" +] + +[build-system] +requires = ["setuptools>=65.6.3", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = { "" = "src" } + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/run_all_cases.sh b/run_all_cases.sh new file mode 100755 index 
0000000..6252554 --- /dev/null +++ b/run_all_cases.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# 自动执行所有场景的所有case +# 总共5个场景,每个场景4个case,共20个case +# 设置环境变量(离线模式) +export HF_HUB_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 + +# 颜色定义 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 定义所有场景 +SCENARIOS=( + "unitree_g1_pack_camera" + "unitree_z1_dual_arm_cleanup_pencils" + "unitree_z1_dual_arm_stackbox" + "unitree_z1_dual_arm_stackbox_v2" + "unitree_z1_stackbox" +) + +# 定义case数量 +CASES=(1 2 3 4) + +# 记录开始时间 +START_TIME=$(date +%s) +LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log" + +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}开始执行所有场景的case${NC}" +echo -e "${BLUE}总共: ${#SCENARIOS[@]} 个场景 x ${#CASES[@]} 个case = $((${#SCENARIOS[@]} * ${#CASES[@]})) 个任务${NC}" +echo -e "${BLUE}日志文件: ${LOG_FILE}${NC}" +echo -e "${BLUE}========================================${NC}" +echo "" + +# 初始化计数器 +TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]})) +CURRENT_CASE=0 +SUCCESS_COUNT=0 +FAIL_COUNT=0 + +# 记录失败的case +declare -a FAILED_CASES + +# 遍历所有场景 +for scenario in "${SCENARIOS[@]}"; do + echo -e "${YELLOW}>>> 场景: ${scenario}${NC}" + + # 遍历所有case + for case_num in "${CASES[@]}"; do + CURRENT_CASE=$((CURRENT_CASE + 1)) + case_dir="${scenario}/case${case_num}" + script_path="${case_dir}/run_world_model_interaction.sh" + + echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] 执行: ${case_dir}${NC}" + + # 检查脚本是否存在 + if [ ! -f "${script_path}" ]; then + echo -e "${RED}错误: 脚本不存在 ${script_path}${NC}" + FAIL_COUNT=$((FAIL_COUNT + 1)) + FAILED_CASES+=("${case_dir} (脚本不存在)") + continue + fi + + # 执行脚本 + echo "开始时间: $(date '+%Y-%m-%d %H:%M:%S')" + + if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then + echo -e "${GREEN}✓ 成功: ${case_dir}${NC}" + SUCCESS_COUNT=$((SUCCESS_COUNT + 1)) + else + echo -e "${RED}✗ 失败: ${case_dir}${NC}" + FAIL_COUNT=$((FAIL_COUNT + 1)) + FAILED_CASES+=("${case_dir}") + fi + + echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')" + echo "" + done + + echo "" +done + +# 计算总耗时 +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +HOURS=$((DURATION / 3600)) +MINUTES=$(((DURATION % 3600) / 60)) +SECONDS=$((DURATION % 60)) + +# 输出总结 +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}执行完成!${NC}" +echo -e "${BLUE}========================================${NC}" +echo -e "总任务数: ${TOTAL_CASES}" +echo -e "${GREEN}成功: ${SUCCESS_COUNT}${NC}" +echo -e "${RED}失败: ${FAIL_COUNT}${NC}" +echo -e "总耗时: ${HOURS}小时 ${MINUTES}分钟 ${SECONDS}秒" +echo -e "详细日志: ${LOG_FILE}" +echo "" + +# 如果有失败的case,列出来 +if [ ${FAIL_COUNT} -gt 0 ]; then + echo -e "${RED}失败的case列表:${NC}" + for failed_case in "${FAILED_CASES[@]}"; do + echo -e "${RED} - ${failed_case}${NC}" + done + echo "" +fi + +echo -e "${BLUE}========================================${NC}" diff --git a/scripts/evaluation/base_model_inference.py b/scripts/evaluation/base_model_inference.py new file mode 100644 index 0000000..42945a7 --- /dev/null +++ b/scripts/evaluation/base_model_inference.py @@ -0,0 +1,541 @@ +import argparse, os, glob +import datetime, time +import pandas as pd +import torch +import torchvision +import torchvision.transforms as transforms +import random + +from pytorch_lightning import seed_everything +from PIL import Image +from omegaconf import OmegaConf +from tqdm import tqdm +from einops import rearrange, repeat +from collections import OrderedDict + +from unifolm_wma.models.samplers.ddim import DDIMSampler +from unifolm_wma.utils.utils import 
instantiate_from_config + + +def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: + """ + Get list of files in `data_dir` with extensions in `postfixes`. + + Args: + data_dir (str): Directory path. + postfixes (list[str]): List of file extensions (e.g., ['csv', 'jpg']). + + Returns: + list[str]: Sorted list of matched file paths. + """ + patterns = [ + os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes + ] + file_list = [] + for pattern in patterns: + file_list.extend(glob.glob(pattern)) + file_list.sort() + return file_list + + +def load_model_checkpoint(model: torch.nn.Module, + ckpt: str) -> torch.nn.Module: + """ + Load model weights from checkpoint file. + + Args: + model (torch.nn.Module): The model to load weights into. + ckpt (str): Path to the checkpoint file. + + Returns: + torch.nn.Module: Model with weights loaded. + """ + state_dict = torch.load(ckpt, map_location="cpu") + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + try: + loaded = model.load_state_dict(state_dict, strict=False) + print("Missing keys:") + for k in loaded.missing_keys: + print(f" {k}") + print("Unexpected keys:") + for k in loaded.unexpected_keys: + print(f" {k}") + + except: + # Rename the keys for 256x256 model + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + model.load_state_dict(new_pl_sd, strict=False) + else: + new_pl_sd = OrderedDict() + for key in state_dict['module'].keys(): + new_pl_sd[key[16:]] = state_dict['module'][key] + model.load_state_dict(new_pl_sd) + print('>>> model checkpoint loaded.') + return model + + +def load_prompts(prompt_file: str) -> list[str]: + """ + Load prompts from a text file, one per line. + + Args: + prompt_file (str): Path to the prompt file. + + Returns: + list[str]: List of prompt strings. + """ + f = open(prompt_file, 'r') + prompt_list = [] + for idx, line in enumerate(f.readlines()): + l = line.strip() + if len(l) != 0: + prompt_list.append(l) + f.close() + return prompt_list + + +def load_data_prompts( + data_dir: str, + savedir: str, + video_size: tuple[int, int] = (256, 256), + video_frames: int = 16 +) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float], + list[int]]: + """ + Load image prompts, process them into video format, and retrieve metadata. + + Args: + data_dir (str): Directory containing images and CSV file. + savedir (str): Output directory to check if inference was already done. + video_size (tuple[int, int], optional): Target size of video frames. + video_frames (int, optional): Number of frames in each video. + + Returns: + tuple: (filenames, video tensors, prompts, fps values, fs values, num_generations) + """ + + transform = transforms.Compose([ + transforms.Resize(min(video_size)), + transforms.CenterCrop(video_size), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + # Load prompt csv + prompt_file = get_filelist(data_dir, ['csv']) + assert len(prompt_file) > 0, "Error: found NO image prompt file!" 
+ + # Load image prompts + file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG']) + data_list = [] + filename_list = [] + prompt_list = [] + fps_list = [] + fs_list = [] + num_gen_list = [] + prompt_csv = pd.read_csv(prompt_file[0]) + n_samples = len(file_list) + + for idx in range(n_samples): + image = Image.open(file_list[idx]).convert('RGB') + image_tensor = transform(image).unsqueeze(1) + frame_tensor = repeat(image_tensor, + 'c t h w -> c (repeat t) h w', + repeat=video_frames) + _, filename = os.path.split(file_list[idx]) + + if not is_inferenced(savedir, filename): + video_id = filename[:-4] + prompt_csv['videoid'] = prompt_csv['videoid'].map(str) + if not (prompt_csv['videoid'] == video_id).any(): + continue + data_list.append(frame_tensor) + filename_list.append(filename) + ins = prompt_csv[prompt_csv['videoid'] == + video_id]['instruction'].values[0] + prompt_list.append(ins) + fps = prompt_csv[prompt_csv['videoid'] == + video_id]['fps'].values[0] + fps_list.append(fps) + fs = prompt_csv[prompt_csv['videoid'] == video_id]['fs'].values[0] + fs_list.append(fs) + num_gen = prompt_csv[prompt_csv['videoid'] == + video_id]['num_gen'].values[0] + num_gen_list.append(int(num_gen)) + + return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list + + +def is_inferenced(save_dir: str, filename: str) -> bool: + """ + Check if a result video already exists. + + Args: + save_dir (str): Directory where results are saved. + filename (str): Base filename to check. + + Returns: + bool: True if file exists, else False. + """ + video_file = os.path.join(save_dir, f"{filename[:-4]}.mp4") + return os.path.exists(video_file) + + +def save_results_seperate(prompt: str | list[str], + samples: torch.Tensor, + filename: str, + fakedir: str, + fps: int = 8) -> None: + """ + Save generated video samples as .mp4 files. + + Args: + prompt (str | list[str]): The prompt text. + samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W]. + filename (str): Output filename. + fakedir (str): Directory to save output videos. + fps (int, optional): Frames per second. + + Returns: + None + """ + prompt = prompt[0] if isinstance(prompt, list) else prompt + + # Save video + videos = [samples] + savedirs = [fakedir] + for idx, video in enumerate(videos): + if video is None: + continue + video = video.detach().cpu() + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + for i in range(n): + grid = video[i, ...] + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) + path = os.path.join(savedirs[idx], f'{filename.split(".")[0]}.mp4') + torchvision.io.write_video(path, + grid, + fps=fps, + video_codec='h264', + options={'crf': '0'}) + + +def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor: + """ + Encode videos to latent space. + + Args: + model (torch.nn.Module): Model with encode_first_stage function. + videos (torch.Tensor): Video tensor of shape [B, C, T, H, W]. + + Returns: + torch.Tensor: Latent representation of shape [B, C, T, H, W]. 
+ """ + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + return z + + +def image_guided_synthesis(model: torch.nn.Module, + prompts: list[str], + videos: torch.Tensor, + noise_shape: list[int], + ddim_steps: int = 50, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + fs: int | None = None, + text_input: bool = False, + timestep_spacing: str = 'uniform', + guidance_rescale: float = 0.0, + **kwargs) -> torch.Tensor: + """ + Run DDIM-based image-to-video synthesis with hybrid/text+image guidance. + + Args: + model (torch.nn.Module): Diffusion model. + prompts (list[str]): Text prompts. + videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W]. + noise_shape (list[int]): Latent noise shape [B, C, T, H, W]. + ddim_steps (int, optional): Number of DDIM steps. + ddim_eta (float, optional): Eta value for DDIM. + unconditional_guidance_scale (float, optional): Guidance scale. + fs (int | None, optional): FPS input for sampler. + text_input (bool, optional): If True, use text guidance. + timestep_spacing (str, optional): Timestep schedule spacing. + guidance_rescale (float, optional): Rescale guidance effect. + **kwargs: Additional sampler args. + + Returns: + torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W]. + """ + + ddim_sampler = DDIMSampler(model) + batch_size = noise_shape[0] + fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) + + if not text_input: + prompts = [""] * batch_size + + b, c, t, h, w = videos.shape + img = videos[:, :, 0] + img_emb = model.embedder(img) + img_emb = model.image_proj_model(img_emb) + img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t) + cond_emb = model.get_learned_conditioning(prompts) + cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0) + + cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]} + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, videos) + img_cat_cond = z[:, :, :1, :, :] + img_cat_cond = repeat(img_cat_cond, + 'b c t h w -> b c (repeat t) h w', + repeat=z.shape[2]) + cond["c_concat"] = [img_cat_cond] + + uc = None + cond_mask = None + kwargs.update({"unconditional_conditioning_img_nonetext": None}) + + batch_variants = [] + if ddim_sampler is not None: + samples, _, _, _ = ddim_sampler.sample( + S=ddim_steps, + batch_size=batch_size, + shape=noise_shape[1:], + conditioning=cond, + eta=ddim_eta, + mask=cond_mask, + x0=None, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs) + + # Reconstruct from latent to pixel space + batch_images = model.decode_first_stage(samples) + batch_variants.append(batch_images) + + batch_variants = torch.stack(batch_variants) + return batch_variants.permute(1, 0, 2, 3, 4, 5) + + +def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: + """ + Run inference pipeline on prompts and image inputs. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + gpu_num (int): Number of GPUs. + gpu_no (int): Index of the current GPU. 
+ + Returns: + None + """ + # Load config + config = OmegaConf.load(args.config) + # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set" + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model = model.cuda(gpu_no) + model.perframe_ae = args.perframe_ae + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + model = load_model_checkpoint(model, args.ckpt_path) + model.eval() + + # Run over data + assert (args.height % 16 == 0) and ( + args.width % 16 + == 0), "Error: image size [h,w] should be multiples of 16!" + assert args.bs == 1, "Current implementation only support [batch size = 1]!" + + # Get latent noise shape + h, w = args.height // 8, args.width // 8 + channels = model.model.diffusion_model.out_channels + n_frames = args.video_length + print(f'>>> Generate {n_frames} frames under each generation ...') + noise_shape = [args.bs, channels, n_frames, h, w] + + fakedir = os.path.join(args.savedir, "samples") + os.makedirs(fakedir, exist_ok=True) + + # Prompt file setting + assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!" + filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts( + args.prompt_dir, + args.savedir, + video_size=(args.height, args.width), + video_frames=n_frames) + + num_samples = len(prompt_list) + samples_split = num_samples // gpu_num + print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' % + (gpu_no, samples_split, num_samples)) + + indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1))) + fps_list_rank = [fps_list[i] for i in indices] + fs_list_rank = [fs_list[i] for i in indices] + prompt_list_rank = [prompt_list[i] for i in indices] + data_list_rank = [data_list[i] for i in indices] + filename_list_rank = [filename_list[i] for i in indices] + + with torch.no_grad(), torch.cuda.amp.autocast(): + # Create a new result csv + for idx, indice in enumerate( + tqdm(range(0, len(prompt_list_rank), args.bs), + desc=f'Sample batch')): + fps = fps_list_rank[indice:indice + args.bs] + fs = fs_list_rank[indice:indice + args.bs] + prompts = prompt_list_rank[indice:indice + args.bs] + num_gen = num_gen_list[indice:indice + args.bs] + videos = data_list_rank[indice:indice + args.bs] + filenames = filename_list_rank[indice:indice + args.bs] + if isinstance(videos, list): + videos = torch.stack(videos, dim=0).to("cuda") + else: + videos = videos.unsqueeze(0).to("cuda") + + results = [] + print( + f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..." + ) + for _ in range(num_gen[0]): + batch_samples = image_guided_synthesis( + model, prompts, videos, noise_shape, args.ddim_steps, + args.ddim_eta, args.unconditional_guidance_scale, + fps[0] // fs[0], args.text_input, args.timestep_spacing, + args.guidance_rescale) + results.extend(batch_samples) + videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2), + 'b c t h w -> b c (repeat t) h w', + repeat=batch_samples[0].shape[2]) + batch_samples = [torch.concat(results, axis=2)] + + # Save each example individually + for nn, samples in enumerate(batch_samples): + prompt = prompts[nn] + filename = filenames[nn] + save_results_seperate(prompt, + samples, + filename, + fakedir, + fps=8) + + +def get_parser() -> argparse.ArgumentParser: + """ + Create and return the argument parser. + + Returns: + argparse.ArgumentParser: Parser for command-line arguments. 
+ """ + parser = argparse.ArgumentParser() + + parser.add_argument("--savedir", + type=str, + default=None, + help="Path to save the results.") + parser.add_argument("--ckpt_path", + type=str, + default=None, + help="Path to the model checkpoint.") + parser.add_argument("--config", + type=str, + help="Path to the YAML configuration file.") + parser.add_argument( + "--prompt_dir", + type=str, + default=None, + help="Directory containing videos and corresponding prompts.") + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="Number of DDIM steps. If non-positive, DDPM is used instead.") + parser.add_argument( + "--ddim_eta", + type=float, + default=1.0, + help="Eta for DDIM sampling. Set to 0.0 for deterministic results.") + parser.add_argument("--bs", + type=int, + default=1, + help="Batch size for inference. Must be 1.") + parser.add_argument("--height", + type=int, + default=320, + help="Height of the generated images in pixels.") + parser.add_argument("--width", + type=int, + default=512, + help="Width of the generated images in pixels.") + parser.add_argument( + "--unconditional_guidance_scale", + type=float, + default=1.0, + help="Scale for classifier-free guidance during sampling.") + parser.add_argument("--seed", + type=int, + default=123, + help="Random seed for reproducibility.") + parser.add_argument("--video_length", + type=int, + default=16, + help="Number of frames in the generated video.") + parser.add_argument( + "--text_input", + action='store_true', + default=False, + help= + "Whether to provide a text prompt as input to the image-to-video model." + ) + parser.add_argument( + "--timestep_spacing", + type=str, + default="uniform", + help= + "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." + ) + parser.add_argument( + "--guidance_rescale", + type=float, + default=0.0, + help= + "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." + ) + parser.add_argument( + "--perframe_ae", + action='store_true', + default=False, + help= + "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." + ) + return parser + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + seed = args.seed + if seed < 0: + seed = random.randint(0, 2**31) + seed_everything(seed) + rank, gpu_num = 0, 1 + run_inference(args, gpu_num, rank) diff --git a/scripts/evaluation/eval_utils.py b/scripts/evaluation/eval_utils.py new file mode 100644 index 0000000..366335a --- /dev/null +++ b/scripts/evaluation/eval_utils.py @@ -0,0 +1,77 @@ +import torch +import warnings +import torchvision +import sys +import pyarrow as pa +import logging + +from dataclasses import dataclass, field +from typing import Dict, Any, ClassVar, Deque, Mapping, Union +from datasets.features.features import register_feature +from torch.utils.tensorboard.writer import SummaryWriter + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + + +@dataclass +class VideoFrame: + """ + Provides a type for a dataset containing video frames. 
+ + Example: + + ```python + data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}] + features = {"image": VideoFrame()} + Dataset.from_dict(data_dict, features=Features(features)) + ``` + """ + + pa_type: ClassVar[Any] = pa.struct({ + "path": pa.string(), + "timestamp": pa.float32() + }) + _type: str = field(default="VideoFrame", init=False, repr=False) + + def __call__(self): + return self.pa_type + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "'register_feature' is experimental and might be subject to breaking changes in the future.", + category=UserWarning, + ) + register_feature(VideoFrame, "VideoFrame") + + +def populate_queues( + queues: Dict[str, Deque[Any]], + batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]: + + for key in batch: + if key not in queues: + continue + if len(queues[key]) != queues[key].maxlen: + while len(queues[key]) != queues[key].maxlen: + queues[key].append(batch[key]) + else: + queues[key].append(batch[key]) + return queues + + +def log_to_tensorboard( + writer: SummaryWriter, + data: Union[torch.Tensor, Any], + tag: str, + fps: int = 10) -> None: + if isinstance(data, torch.Tensor) and data.dim() == 5: + video = data + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + writer.add_video(tag, grid, fps=fps) diff --git a/scripts/evaluation/real_eval_server.py b/scripts/evaluation/real_eval_server.py new file mode 100644 index 0000000..d780b5d --- /dev/null +++ b/scripts/evaluation/real_eval_server.py @@ -0,0 +1,463 @@ +import argparse, os, sys +import torch +import torchvision +import warnings +import imageio +import logging +import matplotlib.pyplot as plt +plt.switch_backend('agg') +import traceback +import uvicorn + +from omegaconf import OmegaConf +from einops import rearrange, repeat +from collections import OrderedDict +from pytorch_lightning import seed_everything +from torch import nn +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from typing import Any, Dict, Optional, Tuple, List +from datetime import datetime + +from unifolm_wma.utils.utils import instantiate_from_config +from unifolm_wma.models.samplers.ddim import DDIMSampler + + +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Args: + module (nn.Module): PyTorch module. + + Returns: + torch.device: The device where the module's parameters are stored. + """ + return next(iter(module.parameters())).device + + +def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: + """Load model weights from checkpoint file. + + Args: + model (nn.Module): Model to load weights into. + ckpt (str): Path to checkpoint file. + + Returns: + nn.Module: Model with loaded weights. 
+ """ + + state_dict = torch.load(ckpt, map_location="cpu") + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + try: + model.load_state_dict(state_dict, strict=False) + except: + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + model.load_state_dict(new_pl_sd, strict=False) + else: + new_pl_sd = OrderedDict() + for key in state_dict['module'].keys(): + new_pl_sd[key[16:]] = state_dict['module'][key] + model.load_state_dict(new_pl_sd) + print('>>> model checkpoint loaded.') + return model + + +def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None: + """Write a video to disk using imageio. + + Args: + video_path (str): Path to save the video. + stacked_frames (List[Any]): Frames to write. + fps (int): Frames per second. + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + "pkg_resources is deprecated as an API", + category=DeprecationWarning) + imageio.mimsave(video_path, stacked_frames, fps=fps) + + +def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None: + """Save a video tensor as an MP4 file. + + Args: + video (torch.Tensor): Video tensor of shape (B, C, T, H, W). + filename (str): Path to save video. + fps (int, optional): Frame rate. Defaults to 8. + + """ + video = video.detach().cpu() + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + torchvision.io.write_video(filename, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor: + """Encode videos into latent space. + + Args: + model (nn.Module): Model with `encode_first_stage` method. + videos (torch.Tensor): Input videos (B, C, T, H, W). + + Returns: + torch.Tensor: Latent representation (B, C, T, H, W). + + """ + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + return z + + +def image_guided_synthesis( + model: torch.nn.Module, + prompts: list[str], + observation: Dict[str, torch.Tensor], + noise_shape: tuple[int, int, int, int, int], + ddim_steps: int = 50, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + fs: int | None = None, + timestep_spacing: str = 'uniform', + guidance_rescale: float = 0.0, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference with DDIM sampling. + + Args: + model (nn.Module): Diffusion model. + prompts (Any): Conditioning text prompts. + observation (Dict[str, torch.Tensor]): Observation dictionary. + noise_shape (List[int]): Shape of noise tensor. + ddim_steps (int, optional): Number of DDIM steps. Defaults to 50. + ddim_eta (float, optional): Sampling eta. Defaults to 1.0. + unconditional_guidance_scale (float, optional): Guidance scale. Defaults to 1.0. + fs (Optional[int], optional): Frame stride or FPS. Defaults to None. + timestep_spacing (str, optional): Spacing strategy. Defaults to "uniform". 
+ guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0. + **kwargs (Any): Additional arguments. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + + b, _, t, _, _ = noise_shape + ddim_sampler = DDIMSampler(model) + batch_size = noise_shape[0] + fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) + + img = observation['observation.images.top'] + cond_img = img[:, -1, ...] + cond_img_emb = model.embedder(cond_img) + cond_img_emb = model.image_proj_model(cond_img_emb) + + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) + img_cat_cond = z[:, :, -1:, :, :] + img_cat_cond = repeat(img_cat_cond, + 'b c t h w -> b c (repeat t) h w', + repeat=noise_shape[2]) + cond = {"c_concat": [img_cat_cond]} + + cond_ins_emb = model.get_learned_conditioning(prompts) + cond_state = model.state_projector(observation['observation.state']) + cond_state_emb = model.agent_state_pos_emb + cond_state + + cond_action = model.action_projector(observation['action']) + cond_action_emb = model.agent_action_pos_emb + cond_action + cond_action_emb = torch.zeros_like(cond_action_emb) + + cond["c_crossattn"] = [ + torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1) + ] + cond["c_crossattn_action"] = [ + observation['observation.images.top'].permute( + 0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:], + observation['observation.state'][:, -model.n_obs_steps_acting:] + ] + + uc = None + + kwargs.update({"unconditional_conditioning_img_nonetext": None}) + + cond_mask = None + cond_z0 = None + + if ddim_sampler is not None: + + samples, actions, states, intermedia = ddim_sampler.sample( + S=ddim_steps, + conditioning=cond, + batch_size=batch_size, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=None, + mask=cond_mask, + x0=cond_z0, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs) + + # Reconstruct from latent to pixel space + batch_images = model.decode_first_stage(samples) + batch_variants = batch_images + + return batch_variants, actions, states + + +def run_inference(args: argparse.Namespace, gpu_num: int, + gpu_no: int) -> Tuple[nn.Module, List[int], Any]: + """ + Run inference pipeline on prompts and image inputs. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + gpu_num (int): Number of GPUs. + gpu_no (int): Index of the current GPU. + + Returns: + None + """ + # Load config + config = OmegaConf.load(args.config) + # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set" + config['model']['params']['wma_config']['params']['use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + model = load_model_checkpoint(model, args.ckpt_path) + model = model.cuda(gpu_no) + model.eval() + print(">>> Model is successfully loaded ...") + + # Build unnomalizer + logging.info("***** Configing Data *****") + data = instantiate_from_config(config.data) + data.setup() + print(">>> Dataset is successfully loaded ...") + + ## Run over data + assert (args.height % 16 == 0) and ( + args.width % 16 + == 0), "Error: image size [h,w] should be multiples of 16!" + assert args.bs == 1, "Current implementation only support [batch size = 1]!" 
+ + ## Get latent noise shape + h, w = args.height // 8, args.width // 8 + channels = model.model.diffusion_model.out_channels + n_frames = args.video_length + print(f'>>> Generate {n_frames} frames under each generation ...') + noise_shape = [args.bs, channels, n_frames, h, w] + + return model, noise_shape, data + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--savedir", + type=str, + default=None, + help="Path to save the results.") + parser.add_argument("--ckpt_path", + type=str, + default=None, + help="Path to the model checkpoint.") + parser.add_argument("--config", type=str, help="Path to the config file.") + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="Number of DDIM steps. If non-positive, DDPM is used instead.") + parser.add_argument( + "--ddim_eta", + type=float, + default=1.0, + help="Eta for DDIM sampling. Set to 0.0 for deterministic results.") + parser.add_argument("--bs", + type=int, + default=1, + help="Batch size for inference. Must be 1.") + parser.add_argument("--height", + type=int, + default=320, + help="Height of the generated images in pixels.") + parser.add_argument("--width", + type=int, + default=512, + help="Width of the generated images in pixels.") + parser.add_argument( + "--frame_stride", + type=int, + default=3, + help= + "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)" + ) + parser.add_argument( + "--unconditional_guidance_scale", + type=float, + default=1.0, + help="Scale for classifier-free guidance during sampling.") + parser.add_argument("--seed", + type=int, + default=123, + help="Random seed for reproducibility.") + parser.add_argument("--video_length", + type=int, + default=16, + help="Number of frames in the generated video.") + parser.add_argument( + "--timestep_spacing", + type=str, + default="uniform", + help= + "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." + ) + parser.add_argument( + "--guidance_rescale", + type=float, + default=0.0, + help= + "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." + ) + parser.add_argument( + "--perframe_ae", + action='store_true', + default=False, + help= + "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." 
+ ) + return parser + + +class Server: + + def __init__(self, args: argparse.Namespace) -> None: + self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0) + self.args_ = args + self.dataset_name = self.data_.dataset_configs['test']['params'][ + 'dataset_name'] + self.device_ = get_device_from_parameters(self.model_) + + def normalize_image(self, image: torch.Tensor) -> torch.Tensor: + return (image / 255 - 0.5) * 2 + + def predict_action(self, payload: Dict[str, Any]) -> Any: + try: + images = payload['observation.images.top'] + states = payload['observation.state'] + actions = payload['action'] # Should be all zeros + language_instruction = payload['language_instruction'] + + images = torch.tensor(images).cuda() + images = self.data_.test_datasets[ + self.dataset_name].spatial_transform(images).unsqueeze(0) + images = self.normalize_image(images) + print(f"images shape: {images.shape} ...") + states = torch.tensor(states) + states = self.data_.test_datasets[self.dataset_name].normalizer( + {'observation.state': states})['observation.state'] + states, _ = self.data_.test_datasets[ + self.dataset_name]._map_to_uni_state(states, "joint position") + print(f"states shape: {states.shape} ...") + actions = torch.tensor(actions) + actions, action_mask = self.data_.test_datasets[ + self.dataset_name]._map_to_uni_action(actions, + "joint position") + print(f"actions shape: {actions.shape} ...") + print("=" * 20) + states = states.unsqueeze(0).cuda() + actions = actions.unsqueeze(0).cuda() + + observation = { + 'observation.images.top': images, + 'observation.state': states, + 'action': actions + } + observation = { + key: observation[key].to(self.device_, non_blocking=True) + for key in observation + } + + args = self.args_ + pred_videos, pred_action, _ = image_guided_synthesis( + self.model_, + language_instruction, + observation, + self.noise_shape_, + ddim_steps=args.ddim_steps, + ddim_ets=args.ddim_eta, + unconditional_guidance_scale=args.unconditional_guidance_scale, + fs=30 / args.frame_stride, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale) + + pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu() + pred_action = self.data_.test_datasets[ + self.dataset_name].unnormalizer({'action': + pred_action})['action'] + + os.makedirs(args.savedir, exist_ok=True) + current_time = datetime.now().strftime("%H:%M:%S") + video_file = f'{args.savedir}/{current_time}.mp4' + save_results(pred_videos.cpu(), video_file) + + response = { + 'result': 'ok', + 'action': pred_action.tolist(), + 'desc': 'success' + } + return JSONResponse(response) + + except: + logging.error(traceback.format_exc()) + logging.warning( + "Your request threw an error; make sure your request complies with the expected format:\n" + "{'image': np.ndarray, 'instruction': str}\n" + "You can optionally an `unnorm_key: str` to specific the dataset statistics you want to use for " + "de-normalizing the output actions.") + return {'result': 'error', 'desc': traceback.format_exc()} + + def run(self, host: str = "127.0.0.1", port: int = 8000) -> None: + self.app = FastAPI() + self.app.post("/predict_action")(self.predict_action) + print(">>> Inference server is ready ... ") + uvicorn.run(self.app, host=host, port=port) + print(">>> Inference server stops ... ") + return + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + seed = args.seed + seed_everything(seed) + rank, gpu_num = 0, 1 + print(">>> Launch inference server ... 
") + server = Server(args) + server.run() diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py new file mode 100644 index 0000000..c87bae5 --- /dev/null +++ b/scripts/evaluation/world_model_interaction.py @@ -0,0 +1,1220 @@ +import argparse, os, glob +import pandas as pd +import random +import torch +import torchvision +import h5py +import numpy as np +import logging +import einops +import warnings +import imageio +import time +import json +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass, field, asdict +from typing import Optional, Dict, List, Any + +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from tqdm import tqdm +from einops import rearrange, repeat +from collections import OrderedDict +from torch import nn +from eval_utils import populate_queues, log_to_tensorboard +from collections import deque +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from PIL import Image + +from unifolm_wma.models.samplers.ddim import DDIMSampler +from unifolm_wma.utils.utils import instantiate_from_config + + +# ========== Profiling Infrastructure ========== +@dataclass +class TimingRecord: + """Record for a single timing measurement.""" + name: str + start_time: float = 0.0 + end_time: float = 0.0 + cuda_time_ms: float = 0.0 + count: int = 0 + children: List['TimingRecord'] = field(default_factory=list) + + @property + def cpu_time_ms(self) -> float: + return (self.end_time - self.start_time) * 1000 + + def to_dict(self) -> dict: + return { + 'name': self.name, + 'cpu_time_ms': self.cpu_time_ms, + 'cuda_time_ms': self.cuda_time_ms, + 'count': self.count, + 'children': [c.to_dict() for c in self.children] + } + + +class ProfilerManager: + """Manages macro and micro-level profiling.""" + + def __init__(self, enabled: bool = False, output_dir: str = "./profile_output"): + self.enabled = enabled + self.output_dir = output_dir + self.macro_timings: Dict[str, List[float]] = {} + self.cuda_events: Dict[str, List[tuple]] = {} + self.memory_snapshots: List[Dict] = [] + self.pytorch_profiler = None + self.current_iteration = 0 + self.operator_stats: Dict[str, Dict] = {} + + if enabled: + os.makedirs(output_dir, exist_ok=True) + + @contextmanager + def profile_section(self, name: str, sync_cuda: bool = True): + """Context manager for profiling a code section.""" + if not self.enabled: + yield + return + + if sync_cuda and torch.cuda.is_available(): + torch.cuda.synchronize() + + start_event = None + end_event = None + if torch.cuda.is_available(): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + start_time = time.perf_counter() + + try: + yield + finally: + if sync_cuda and torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.perf_counter() + cpu_time_ms = (end_time - start_time) * 1000 + + cuda_time_ms = 0.0 + if start_event is not None and end_event is not None: + end_event.record() + torch.cuda.synchronize() + cuda_time_ms = start_event.elapsed_time(end_event) + + if name not in self.macro_timings: + self.macro_timings[name] = [] + self.macro_timings[name].append(cpu_time_ms) + + if name not in self.cuda_events: + self.cuda_events[name] = [] + self.cuda_events[name].append((cpu_time_ms, cuda_time_ms)) + + def record_memory(self, tag: str = ""): + """Record current GPU memory state.""" + if not self.enabled or not torch.cuda.is_available(): + return + + 
snapshot = { + 'tag': tag, + 'iteration': self.current_iteration, + 'allocated_mb': torch.cuda.memory_allocated() / 1024**2, + 'reserved_mb': torch.cuda.memory_reserved() / 1024**2, + 'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2, + } + self.memory_snapshots.append(snapshot) + + def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3): + """Start PyTorch profiler for operator-level analysis.""" + if not self.enabled: + return nullcontext() + + self.pytorch_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=1 + ), + on_trace_ready=self._trace_handler, + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + with_modules=True, + ) + return self.pytorch_profiler + + def _trace_handler(self, prof): + """Handle profiler trace output.""" + trace_path = os.path.join( + self.output_dir, + f"trace_iter_{self.current_iteration}.json" + ) + prof.export_chrome_trace(trace_path) + + # Extract operator statistics + key_averages = prof.key_averages(group_by_input_shape=True) + for evt in key_averages: + op_name = evt.key + if op_name not in self.operator_stats: + self.operator_stats[op_name] = { + 'count': 0, + 'cpu_time_total_us': 0, + 'cuda_time_total_us': 0, + 'self_cpu_time_total_us': 0, + 'self_cuda_time_total_us': 0, + 'cpu_memory_usage': 0, + 'cuda_memory_usage': 0, + 'flops': 0, + } + stats = self.operator_stats[op_name] + stats['count'] += evt.count + stats['cpu_time_total_us'] += evt.cpu_time_total + stats['cuda_time_total_us'] += evt.cuda_time_total + stats['self_cpu_time_total_us'] += evt.self_cpu_time_total + stats['self_cuda_time_total_us'] += evt.self_cuda_time_total + if hasattr(evt, 'cpu_memory_usage'): + stats['cpu_memory_usage'] += evt.cpu_memory_usage + if hasattr(evt, 'cuda_memory_usage'): + stats['cuda_memory_usage'] += evt.cuda_memory_usage + if hasattr(evt, 'flops') and evt.flops: + stats['flops'] += evt.flops + + def step_profiler(self): + """Step the PyTorch profiler.""" + if self.pytorch_profiler is not None: + self.pytorch_profiler.step() + + def generate_report(self) -> str: + """Generate comprehensive profiling report.""" + if not self.enabled: + return "Profiling disabled." 
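+        # The report is assembled as a list of plain-text lines so the same
+        # string can be printed to the console and written to disk unchanged.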
+ + report_lines = [] + report_lines.append("=" * 80) + report_lines.append("PERFORMANCE PROFILING REPORT") + report_lines.append("=" * 80) + report_lines.append("") + + # Macro-level timing summary + report_lines.append("-" * 40) + report_lines.append("MACRO-LEVEL TIMING SUMMARY") + report_lines.append("-" * 40) + report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}") + report_lines.append("-" * 86) + + total_time = 0 + timing_data = [] + for name, times in sorted(self.macro_timings.items()): + cuda_times = [ct for _, ct in self.cuda_events.get(name, [])] + avg_time = np.mean(times) + avg_cuda = np.mean(cuda_times) if cuda_times else 0 + total = sum(times) + total_time += total + timing_data.append({ + 'name': name, + 'count': len(times), + 'total_ms': total, + 'avg_ms': avg_time, + 'cuda_avg_ms': avg_cuda, + 'times': times, + 'cuda_times': cuda_times, + }) + report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}") + + report_lines.append("-" * 86) + report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}") + report_lines.append("") + + # Memory summary + if self.memory_snapshots: + report_lines.append("-" * 40) + report_lines.append("GPU MEMORY SUMMARY") + report_lines.append("-" * 40) + max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots) + avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots]) + report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB") + report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB") + report_lines.append("") + + # Top operators by CUDA time + if self.operator_stats: + report_lines.append("-" * 40) + report_lines.append("TOP 30 OPERATORS BY CUDA TIME") + report_lines.append("-" * 40) + sorted_ops = sorted( + self.operator_stats.items(), + key=lambda x: x[1]['cuda_time_total_us'], + reverse=True + )[:30] + + report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}") + report_lines.append("-" * 96) + + for op_name, stats in sorted_ops: + # Truncate long operator names + display_name = op_name[:47] + "..." 
if len(op_name) > 50 else op_name + report_lines.append( + f"{display_name:<50} {stats['count']:>8} " + f"{stats['cuda_time_total_us']/1000:>12.2f} " + f"{stats['cpu_time_total_us']/1000:>12.2f} " + f"{stats['self_cuda_time_total_us']/1000:>14.2f}" + ) + report_lines.append("") + + # Compute category breakdown + report_lines.append("-" * 40) + report_lines.append("OPERATOR CATEGORY BREAKDOWN") + report_lines.append("-" * 40) + + categories = { + 'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'], + 'Convolution': ['conv', 'cudnn'], + 'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'], + 'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'], + 'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'], + 'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'], + 'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'], + } + + category_times = {cat: 0.0 for cat in categories} + category_times['Other'] = 0.0 + + for op_name, stats in self.operator_stats.items(): + op_lower = op_name.lower() + categorized = False + for cat, keywords in categories.items(): + if any(kw in op_lower for kw in keywords): + category_times[cat] += stats['cuda_time_total_us'] + categorized = True + break + if not categorized: + category_times['Other'] += stats['cuda_time_total_us'] + + total_op_time = sum(category_times.values()) + report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}") + report_lines.append("-" * 57) + for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]): + pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0 + report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%") + report_lines.append("") + + report = "\n".join(report_lines) + return report + + def save_results(self): + """Save all profiling results to files.""" + if not self.enabled: + return + + # Save report + report = self.generate_report() + report_path = os.path.join(self.output_dir, "profiling_report.txt") + with open(report_path, 'w') as f: + f.write(report) + print(f">>> Profiling report saved to: {report_path}") + + # Save detailed JSON data + data = { + 'macro_timings': { + name: { + 'times': times, + 'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])] + } + for name, times in self.macro_timings.items() + }, + 'memory_snapshots': self.memory_snapshots, + 'operator_stats': self.operator_stats, + } + json_path = os.path.join(self.output_dir, "profiling_data.json") + with open(json_path, 'w') as f: + json.dump(data, f, indent=2) + print(f">>> Detailed profiling data saved to: {json_path}") + + # Print summary to console + print("\n" + report) + + +# Global profiler instance +_profiler: Optional[ProfilerManager] = None + +def get_profiler() -> ProfilerManager: + """Get the global profiler instance.""" + global _profiler + if _profiler is None: + _profiler = ProfilerManager(enabled=False) + return _profiler + +def init_profiler(enabled: bool, output_dir: str) -> ProfilerManager: + """Initialize the global profiler.""" + global _profiler + _profiler = ProfilerManager(enabled=enabled, output_dir=output_dir) + return _profiler + + +# ========== Original Functions ========== +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Args: + module (nn.Module): The model whose device is to be inferred. + + Returns: + torch.device: The device of the model's parameters. 
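+
+    Example (illustrative; any ``nn.Module`` with parameters works):
+        >>> device = get_device_from_parameters(model)
+        >>> batch = {k: v.to(device) for k, v in batch.items()}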
+ """ + return next(iter(module.parameters())).device + + +def write_video(video_path: str, stacked_frames: list, fps: int) -> None: + """Save a list of frames to a video file. + + Args: + video_path (str): Output path for the video. + stacked_frames (list): List of image frames. + fps (int): Frames per second for the video. + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + "pkg_resources is deprecated as an API", + category=DeprecationWarning) + imageio.mimsave(video_path, stacked_frames, fps=fps) + + +def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: + """Return sorted list of files in a directory matching specified postfixes. + + Args: + data_dir (str): Directory path to search in. + postfixes (list[str]): List of file extensions to match. + + Returns: + list[str]: Sorted list of file paths. + """ + patterns = [ + os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes + ] + file_list = [] + for pattern in patterns: + file_list.extend(glob.glob(pattern)) + file_list.sort() + return file_list + + +def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: + """Load model weights from checkpoint file. + + Args: + model (nn.Module): Model instance. + ckpt (str): Path to the checkpoint file. + + Returns: + nn.Module: Model with loaded weights. + """ + state_dict = torch.load(ckpt, map_location="cpu") + if "state_dict" in list(state_dict.keys()): + state_dict = state_dict["state_dict"] + try: + model.load_state_dict(state_dict, strict=True) + except: + new_pl_sd = OrderedDict() + for k, v in state_dict.items(): + new_pl_sd[k] = v + + for k in list(new_pl_sd.keys()): + if "framestride_embed" in k: + new_key = k.replace("framestride_embed", "fps_embedding") + new_pl_sd[new_key] = new_pl_sd[k] + del new_pl_sd[k] + model.load_state_dict(new_pl_sd, strict=True) + else: + new_pl_sd = OrderedDict() + for key in state_dict['module'].keys(): + new_pl_sd[key[16:]] = state_dict['module'][key] + model.load_state_dict(new_pl_sd) + print('>>> model checkpoint loaded.') + return model + + +def is_inferenced(save_dir: str, filename: str) -> bool: + """Check if a given filename has already been processed and saved. + + Args: + save_dir (str): Directory where results are saved. + filename (str): Name of the file to check. + + Returns: + bool: True if processed file exists, False otherwise. + """ + video_file = os.path.join(save_dir, "samples_separate", + f"{filename[:-4]}_sample0.mp4") + return os.path.exists(video_file) + + +def save_results(video: Tensor, filename: str, fps: int = 8) -> None: + """Save video tensor to file using torchvision. + + Args: + video (Tensor): Tensor of shape (B, C, T, H, W). + filename (str): Output file path. + fps (int, optional): Frames per second. Defaults to 8. + """ + video = video.detach().cpu() + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + torchvision.io.write_video(filename, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def get_init_frame_path(data_dir: str, sample: dict) -> str: + """Construct the init_frame path from directory and sample metadata. + + Args: + data_dir (str): Base directory containing videos. + sample (dict): Dictionary containing 'data_dir' and 'videoid'. 
+
+
+def get_init_frame_path(data_dir: str, sample: dict) -> str:
+    """Construct the initial-frame path from directory and sample metadata.
+
+    Args:
+        data_dir (str): Base directory containing videos.
+        sample (dict): Dictionary containing 'data_dir' and 'videoid'.
+
+    Returns:
+        str: Full path to the initial-frame image file.
+    """
+    rel_video_fp = os.path.join(sample['data_dir'],
+                                str(sample['videoid']) + '.png')
+    full_image_fp = os.path.join(data_dir, 'images', rel_video_fp)
+    return full_image_fp
+
+
+def get_transition_path(data_dir: str, sample: dict) -> str:
+    """Construct the full transition file path from directory and sample metadata.
+
+    Args:
+        data_dir (str): Base directory containing transition files.
+        sample (dict): Dictionary containing 'data_dir' and 'videoid'.
+
+    Returns:
+        str: Full path to the HDF5 transition file.
+    """
+    rel_transition_fp = os.path.join(sample['data_dir'],
+                                     str(sample['videoid']) + '.h5')
+    full_transition_fp = os.path.join(data_dir, 'transitions',
+                                      rel_transition_fp)
+    return full_transition_fp
+
+
+def prepare_init_input(start_idx: int,
+                       init_frame_path: str,
+                       transition_dict: dict[str, torch.Tensor],
+                       frame_stride: int,
+                       wma_data,
+                       video_length: int = 16,
+                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
+    """
+    Extract a structured sample including the initial frame, states, and actions,
+    with properly padded observations and pre-processed tensors for model input.
+
+    Args:
+        start_idx (int): Starting frame index for the current clip.
+        init_frame_path (str): Path to the initial-frame image.
+        transition_dict (dict[str, Tensor]): Dictionary containing tensors for 'action',
+            'observation.state', 'action_type', 'state_type'.
+        frame_stride (int): Temporal stride between sampled frames.
+        wma_data: Object that holds configuration and utility functions such as
+            normalization, transformation, and resolution info.
+        video_length (int, optional): Number of frames to sample from the video. Default is 16.
+        n_obs_steps (int, optional): Number of historical steps for observations. Default is 2.
+
+    Returns:
+        tuple: The model-input dictionary, the original state dimension, and the
+            original action dimension.
+    """
+
+    indices = [start_idx + frame_stride * i for i in range(video_length)]
+    init_frame = Image.open(init_frame_path).convert('RGB')
+    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
+        3, 0, 1, 2).float()
+
+    if start_idx < n_obs_steps - 1:
+        # Pad the history by repeating the first available state.
+        state_indices = list(range(0, start_idx + 1))
+        states = transition_dict['observation.state'][state_indices, :]
+        num_padding = n_obs_steps - 1 - start_idx
+        first_slice = states[0:1, :]  # (t, d)
+        padding = first_slice.repeat(num_padding, 1)
+        states = torch.cat((padding, states), dim=0)
+    else:
+        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
+        states = transition_dict['observation.state'][state_indices, :]
+
+    actions = transition_dict['action'][indices, :]
+
+    ori_state_dim = states.shape[-1]
+    ori_action_dim = actions.shape[-1]
+
+    frames_action_state_dict = {
+        'action': actions,
+        'observation.state': states,
+    }
+    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
+    frames_action_state_dict = wma_data.get_uni_vec(
+        frames_action_state_dict,
+        transition_dict['action_type'],
+        transition_dict['state_type'],
+    )
+
+    if wma_data.spatial_transform is not None:
+        init_frame = wma_data.spatial_transform(init_frame)
+    init_frame = (init_frame / 255 - 0.5) * 2
+
+    data = {
+        'observation.image': init_frame,
+    }
+    data.update(frames_action_state_dict)
+    return data, ori_state_dim, ori_action_dim
+
+
+def get_latent_z(model, videos: Tensor) -> Tensor:
+    """
+    Extract latent features from a video batch using the model's first-stage encoder.
+
+    Args:
+        model: The world model.
+        videos (Tensor): Input videos of shape [B, C, T, H, W].
+
+    Returns:
+        Tensor: Latent video tensor of shape [B, C, T, H, W].
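+
+    Note:
+        Illustrative shape walk-through, assuming the usual 8x-downsampling
+        first-stage encoder used elsewhere in this script: a
+        [1, 3, 16, 320, 512] video is flattened to [16, 3, 320, 512],
+        encoded frame by frame, and reshaped back to [1, c, 16, 40, 64].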
+ """ + profiler = get_profiler() + with profiler.profile_section("get_latent_z/encode"): + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + return z + + +def preprocess_observation( + model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]: + """Convert environment observation to LeRobot format observation. + Args: + observation: Dictionary of observation batches from a Gym vector environment. + Returns: + Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. + """ + # Map to expected inputs for the policy + return_observations = {} + + if isinstance(observations["pixels"], dict): + imgs = { + f"observation.images.{key}": img + for key, img in observations["pixels"].items() + } + else: + imgs = {"observation.images.top": observations["pixels"]} + + for imgkey, img in imgs.items(): + img = torch.from_numpy(img) + + # Sanity check that images are channel last + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel first images, but instead {img.shape}" + + # Sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + + # Convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + + return_observations[imgkey] = img + + return_observations["observation.state"] = torch.from_numpy( + observations["agent_pos"]).float() + return_observations['observation.state'] = model.normalize_inputs({ + 'observation.state': + return_observations['observation.state'].to(model.device) + })['observation.state'] + + return return_observations + + +def image_guided_synthesis_sim_mode( + model: torch.nn.Module, + prompts: list[str], + observation: dict, + noise_shape: tuple[int, int, int, int, int], + action_cond_step: int = 16, + n_samples: int = 1, + ddim_steps: int = 50, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + fs: int | None = None, + text_input: bool = True, + timestep_spacing: str = 'uniform', + guidance_rescale: float = 0.0, + sim_mode: bool = True, + **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). + + Args: + model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning. + prompts (list[str]): A list of textual prompts to guide the synthesis process. + observation (dict): A dictionary containing observed inputs including: + - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images) + - 'observation.state': Tensor of shape [B, O, D] (state vector) + - 'action': Tensor of shape [B, T, D] (action sequence) + noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate, + typically (B, C, T, H, W). + action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16. + n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1. + ddim_steps (int): Number of DDIM sampling steps. Default is 50. + ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0. + unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off. 
+        fs (int | None): Frame-rate conditioning value, broadcast across the batch if specified. Default is None.
+        text_input (bool): Whether to use the text prompt as conditioning. If False, uses empty strings. Default is True.
+        timestep_spacing (str): Timestep sampling method in the DDIM sampler. Typically "uniform" or "linspace".
+        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
+        sim_mode (bool): Whether to perform world-model interaction or decision-making using the world model.
+        **kwargs: Additional arguments passed to the DDIM sampler.
+
+    Returns:
+        batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
+        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
+        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
+    """
+    profiler = get_profiler()
+
+    b, _, t, _, _ = noise_shape
+    ddim_sampler = DDIMSampler(model)
+    batch_size = noise_shape[0]
+
+    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+    with profiler.profile_section("synthesis/conditioning_prep"):
+        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
+        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
+        cond_img_emb = model.embedder(cond_img)
+        cond_img_emb = model.image_proj_model(cond_img_emb)
+
+        # Initialize the conditioning dict up front so the cross-attention
+        # entries below never hit an undefined name when the conditioning
+        # key is not 'hybrid'.
+        cond: Dict[str, Any] = {}
+        if model.model.conditioning_key == 'hybrid':
+            z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
+            img_cat_cond = z[:, :, -1:, :, :]
+            img_cat_cond = repeat(img_cat_cond,
+                                  'b c t h w -> b c (repeat t) h w',
+                                  repeat=noise_shape[2])
+            cond["c_concat"] = [img_cat_cond]
+
+        if not text_input:
+            prompts = [""] * batch_size
+        cond_ins_emb = model.get_learned_conditioning(prompts)
+
+        cond_state_emb = model.state_projector(observation['observation.state'])
+        cond_state_emb = cond_state_emb + model.agent_state_pos_emb
+
+        cond_action_emb = model.action_projector(observation['action'])
+        cond_action_emb = cond_action_emb + model.agent_action_pos_emb
+
+        if not sim_mode:
+            cond_action_emb = torch.zeros_like(cond_action_emb)
+
+        cond["c_crossattn"] = [
+            torch.cat(
+                [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
+                dim=1)
+        ]
+        cond["c_crossattn_action"] = [
+            observation['observation.images.top'][:, :,
+                                                  -model.n_obs_steps_acting:],
+            observation['observation.state'][:, -model.n_obs_steps_acting:],
+            sim_mode,
+            False,
+        ]
+
+    uc = None
+    kwargs.update({"unconditional_conditioning_img_nonetext": None})
+    cond_mask = None
+    cond_z0 = None
+
+    if ddim_sampler is not None:
+        with profiler.profile_section("synthesis/ddim_sampling"):
+            samples, actions, states, intermedia = ddim_sampler.sample(
+                S=ddim_steps,
+                conditioning=cond,
+                batch_size=batch_size,
+                shape=noise_shape[1:],
+                verbose=False,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=uc,
+                eta=ddim_eta,
+                cfg_img=None,
+                mask=cond_mask,
+                x0=cond_z0,
+                fs=fs,
+                timestep_spacing=timestep_spacing,
+                guidance_rescale=guidance_rescale,
+                **kwargs)
+
+    # Reconstruct from latent to pixel space
+    with profiler.profile_section("synthesis/decode_first_stage"):
+        batch_images = model.decode_first_stage(samples)
+    batch_variants = batch_images
+
+    return batch_variants, actions, states
+
+
+def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
+    """
+    Run the inference pipeline on prompts and image inputs.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments.
+        gpu_num (int): Number of GPUs.
+        gpu_no (int): Index of the current GPU.
+
+    Returns:
+        None
+    """
+    profiler = get_profiler()
+
+    # Create inference and tensorboard dirs
+    os.makedirs(args.savedir + '/inference', exist_ok=True)
+    log_dir = args.savedir + "/tensorboard"
+    os.makedirs(log_dir, exist_ok=True)
+    writer = SummaryWriter(log_dir=log_dir)
+
+    # Load prompts
+    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
+    df = pd.read_csv(csv_path)
+
+    # Load config
+    with profiler.profile_section("model_loading/config"):
+        config = OmegaConf.load(args.config)
+        config['model']['params']['wma_config']['params'][
+            'use_checkpoint'] = False
+        model = instantiate_from_config(config.model)
+        model.perframe_ae = args.perframe_ae
+
+    assert os.path.exists(args.ckpt_path), "Error: checkpoint not found!"
+
+    with profiler.profile_section("model_loading/checkpoint"):
+        model = load_model_checkpoint(model, args.ckpt_path)
+        model.eval()
+        print('>>> Load pre-trained model ...')
+
+    # Build unnormalizer
+    logging.info("***** Configuring Data *****")
+    with profiler.profile_section("data_loading"):
+        data = instantiate_from_config(config.data)
+        data.setup()
+        print(">>> Dataset is successfully loaded ...")
+
+    with profiler.profile_section("model_to_cuda"):
+        model = model.cuda(gpu_no)
+        device = get_device_from_parameters(model)
+
+    profiler.record_memory("after_model_load")
+
+    # Run over data
+    assert (args.height % 16 == 0) and (
+        args.width % 16
+        == 0), "Error: image size [h,w] should be multiples of 16!"
+    assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
+
+    # Get latent noise shape
+    h, w = args.height // 8, args.width // 8
+    channels = model.model.diffusion_model.out_channels
+    n_frames = args.video_length
+    print(f'>>> Generate {n_frames} frames per generation ...')
+    noise_shape = [args.bs, channels, n_frames, h, w]
+
+    # Determine profiler iterations
+    profile_active_iters = getattr(args, 'profile_iterations', 3)
+    use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
+
+    # Start inference
+    for idx in range(0, len(df)):
+        sample = df.iloc[idx]
+
+        # Get the initial frame path
+        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
+        ori_fps = float(sample['fps'])
+
+        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
+        os.makedirs(video_save_dir, exist_ok=True)
+        os.makedirs(video_save_dir + '/dm', exist_ok=True)
+        os.makedirs(video_save_dir + '/wm', exist_ok=True)
+
+        # Load transitions to get the initial state later
+        transition_path = get_transition_path(args.prompt_dir, sample)
+        with profiler.profile_section("load_transitions"):
+            with h5py.File(transition_path, 'r') as h5f:
+                transition_dict = {}
+                for key in h5f.keys():
+                    transition_dict[key] = torch.tensor(h5f[key][()])
+                for key in h5f.attrs.keys():
+                    transition_dict[key] = h5f.attrs[key]
+
+        # If multiple strides are given, test each frequency-control setting
+        # with world-model generation
+        for fs in args.frame_stride:
+
+            # For saving imagined rollouts from the policy
+            sample_save_dir = f'{video_save_dir}/dm/{fs}'
+            os.makedirs(sample_save_dir, exist_ok=True)
+            # For saving environmental changes from the world model
+            sample_save_dir = f'{video_save_dir}/wm/{fs}'
+            os.makedirs(sample_save_dir, exist_ok=True)
+            # For collecting interaction videos
+            wm_video = []
+            # Initialize observation queues
+            cond_obs_queues = {
+                "observation.images.top":
+                deque(maxlen=model.n_obs_steps_imagen),
+                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
+                "action": deque(maxlen=args.video_length),
+            }
+
+            # Obtain initial 
frame and state + with profiler.profile_section("prepare_init_input"): + start_idx = 0 + model_input_fs = ori_fps // fs + batch, ori_state_dim, ori_action_dim = prepare_init_input( + start_idx, + init_frame_path, + transition_dict, + fs, + data.test_datasets[args.dataset], + n_obs_steps=model.n_obs_steps_imagen) + observation = { + 'observation.images.top': + batch['observation.image'].permute(1, 0, 2, + 3)[-1].unsqueeze(0), + 'observation.state': + batch['observation.state'][-1].unsqueeze(0), + 'action': + torch.zeros_like(batch['action'][-1]).unsqueeze(0) + } + observation = { + key: observation[key].to(device, non_blocking=True) + for key in observation + } + # Update observation queues + cond_obs_queues = populate_queues(cond_obs_queues, observation) + + # Setup PyTorch profiler context if enabled + pytorch_prof_ctx = nullcontext() + if use_pytorch_profiler: + pytorch_prof_ctx = profiler.start_pytorch_profiler( + wait=1, warmup=1, active=profile_active_iters + ) + + # Multi-round interaction with the world-model + with pytorch_prof_ctx: + for itr in tqdm(range(args.n_iter)): + profiler.current_iteration = itr + profiler.record_memory(f"iter_{itr}_start") + + with profiler.profile_section("iteration_total"): + # Get observation + with profiler.profile_section("prepare_observation"): + observation = { + 'observation.images.top': + torch.stack(list( + cond_obs_queues['observation.images.top']), + dim=1).permute(0, 2, 1, 3, 4), + 'observation.state': + torch.stack(list(cond_obs_queues['observation.state']), + dim=1), + 'action': + torch.stack(list(cond_obs_queues['action']), dim=1), + } + observation = { + key: observation[key].to(device, non_blocking=True) + for key in observation + } + + # Use world-model in policy to generate action + print(f'>>> Step {itr}: generating actions ...') + with profiler.profile_section("action_generation"): + pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( + model, + sample['instruction'], + observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args. 
+ unconditional_guidance_scale, + fs=model_input_fs, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + sim_mode=False) + + # Update future actions in the observation queues + with profiler.profile_section("update_action_queues"): + for act_idx in range(len(pred_actions[0])): + obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]} + obs_update['action'][:, ori_action_dim:] = 0.0 + cond_obs_queues = populate_queues(cond_obs_queues, + obs_update) + + # Collect data for interacting the world-model using the predicted actions + with profiler.profile_section("prepare_wm_observation"): + observation = { + 'observation.images.top': + torch.stack(list( + cond_obs_queues['observation.images.top']), + dim=1).permute(0, 2, 1, 3, 4), + 'observation.state': + torch.stack(list(cond_obs_queues['observation.state']), + dim=1), + 'action': + torch.stack(list(cond_obs_queues['action']), dim=1), + } + observation = { + key: observation[key].to(device, non_blocking=True) + for key in observation + } + + # Interaction with the world-model + print(f'>>> Step {itr}: interacting with world model ...') + with profiler.profile_section("world_model_interaction"): + pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( + model, + "", + observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args. + unconditional_guidance_scale, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale) + + with profiler.profile_section("update_state_queues"): + for step_idx in range(args.exe_steps): + obs_update = { + 'observation.images.top': + pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3), + 'observation.state': + torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if + args.zero_pred_state else pred_states[0][step_idx:step_idx + 1], + 'action': + torch.zeros_like(pred_actions[0][-1:]) + } + obs_update['observation.state'][:, ori_state_dim:] = 0.0 + cond_obs_queues = populate_queues(cond_obs_queues, + obs_update) + + # Save the imagen videos for decision-making + with profiler.profile_section("save_results"): + sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" + log_to_tensorboard(writer, + pred_videos_0, + sample_tag, + fps=args.save_fps) + # Save videos environment changes via world-model interaction + sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" + log_to_tensorboard(writer, + pred_videos_1, + sample_tag, + fps=args.save_fps) + + # Save the imagen videos for decision-making + sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' + save_results(pred_videos_0.cpu(), + sample_video_file, + fps=args.save_fps) + # Save videos environment changes via world-model interaction + sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' + save_results(pred_videos_1.cpu(), + sample_video_file, + fps=args.save_fps) + + print('>' * 24) + # Collect the result of world-model interactions + wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) + + profiler.record_memory(f"iter_{itr}_end") + profiler.step_profiler() + + full_video = torch.cat(wm_video, dim=2) + sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" + log_to_tensorboard(writer, + full_video, + sample_tag, + fps=args.save_fps) + sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" + save_results(full_video, sample_full_video_file, 
fps=args.save_fps)
+
+    # Save profiling results
+    profiler.save_results()
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--savedir",
+                        type=str,
+                        default=None,
+                        help="Path to save the results.")
+    parser.add_argument("--ckpt_path",
+                        type=str,
+                        default=None,
+                        help="Path to the model checkpoint.")
+    parser.add_argument("--config",
+                        type=str,
+                        help="Path to the config file.")
+    parser.add_argument(
+        "--prompt_dir",
+        type=str,
+        default=None,
+        help="Directory containing videos and corresponding prompts.")
+    parser.add_argument("--dataset",
+                        type=str,
+                        default=None,
+                        help="Name of the dataset to test.")
+    parser.add_argument(
+        "--ddim_steps",
+        type=int,
+        default=50,
+        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
+    parser.add_argument(
+        "--ddim_eta",
+        type=float,
+        default=1.0,
+        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
+    parser.add_argument("--bs",
+                        type=int,
+                        default=1,
+                        help="Batch size for inference. Must be 1.")
+    parser.add_argument("--height",
+                        type=int,
+                        default=320,
+                        help="Height of the generated images in pixels.")
+    parser.add_argument("--width",
+                        type=int,
+                        default=512,
+                        help="Width of the generated images in pixels.")
+    parser.add_argument(
+        "--frame_stride",
+        type=int,
+        nargs='+',
+        required=True,
+        help=
+        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
+    )
+    parser.add_argument(
+        "--unconditional_guidance_scale",
+        type=float,
+        default=1.0,
+        help="Scale for classifier-free guidance during sampling.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=123,
+                        help="Random seed for reproducibility.")
+    parser.add_argument("--video_length",
+                        type=int,
+                        default=16,
+                        help="Number of frames in the generated video.")
+    parser.add_argument("--num_generation",
+                        type=int,
+                        default=1,
+                        help="Number of generations per prompt.")
+    parser.add_argument(
+        "--timestep_spacing",
+        type=str,
+        default="uniform",
+        help=
+        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+    )
+    parser.add_argument(
+        "--guidance_rescale",
+        type=float,
+        default=0.0,
+        help=
+        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+    )
+    parser.add_argument(
+        "--perframe_ae",
+        action='store_true',
+        default=False,
+        help=
+        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
+    )
+    parser.add_argument(
+        "--n_action_steps",
+        type=int,
+        default=16,
+        help="Number of action steps predicted per generation.",
+    )
+    parser.add_argument(
+        "--exe_steps",
+        type=int,
+        default=16,
+        help="Number of predicted action steps to execute per iteration.",
+    )
+    parser.add_argument(
+        "--n_iter",
+        type=int,
+        default=40,
+        help="Number of iterations to interact with the world model.",
+    )
+    parser.add_argument("--zero_pred_state",
+                        action='store_true',
+                        default=False,
+                        help="Zero out the predicted states instead of feeding them back (for comparison).")
+    parser.add_argument("--save_fps",
+                        type=int,
+                        default=8,
+                        help="FPS for the saved videos.")
+    # Profiling arguments
+    parser.add_argument(
+        "--profile",
+        action='store_true',
+        default=False,
+        help="Enable performance profiling (macro and operator-level analysis)."
+    )
+    parser.add_argument(
+        "--profile_output_dir",
+        type=str,
+        default=None,
+        help="Directory to save profiling results. 
Defaults to {savedir}/profile_output." + ) + parser.add_argument( + "--profile_iterations", + type=int, + default=3, + help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis." + ) + return parser + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + seed = args.seed + if seed < 0: + seed = random.randint(0, 2**31) + seed_everything(seed) + + # Initialize profiler + profile_output_dir = args.profile_output_dir + if profile_output_dir is None: + profile_output_dir = os.path.join(args.savedir, "profile_output") + init_profiler(enabled=args.profile, output_dir=profile_output_dir) + + rank, gpu_num = 0, 1 + run_inference(args, gpu_num, rank) diff --git a/scripts/run_base_model_inference.sh b/scripts/run_base_model_inference.sh new file mode 100644 index 0000000..7740d14 --- /dev/null +++ b/scripts/run_base_model_inference.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +model_name=base_model +ckpt=/path/to/base/model +config=configs/inference/base_model_inference.yaml +res_dir="/path/to/result/directory" +seed=123 + +CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \ +--seed ${seed} \ +--ckpt_path $ckpt \ +--config $config \ +--savedir "${res_dir}/videos" \ +--bs 1 --height 320 --width 512 \ +--unconditional_guidance_scale 1.0 \ +--ddim_steps 16 \ +--ddim_eta 1.0 \ +--prompt_dir "/path/to/examples/base_model_prompts" \ +--text_input \ +--video_length 16 \ +--timestep_spacing 'uniform_trailing' \ +--guidance_rescale 0.7 \ +--perframe_ae diff --git a/scripts/run_real_eval_server.sh b/scripts/run_real_eval_server.sh new file mode 100644 index 0000000..f3c1947 --- /dev/null +++ b/scripts/run_real_eval_server.sh @@ -0,0 +1,26 @@ +model_name=testing +ckpt=/path/to/model/checkpoint +config=configs/inference/world_model_decision_making.yaml +seed=123 +res_dir="path/to/results/directory" +datasets=( + "unitree_g1_pack_camera" +) + + +for dataset in "${datasets[@]}"; do + CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \ + --seed ${seed} \ + --ckpt_path $ckpt \ + --config $config \ + --savedir "${res_dir}/${dataset}/${model_name}/videos" \ + --bs 1 --height 320 --width 512 \ + --unconditional_guidance_scale 1.0 \ + --ddim_steps 16 \ + --ddim_eta 1.0 \ + --video_length 16 \ + --frame_stride 2 \ + --timestep_spacing 'uniform_trailing' \ + --guidance_rescale 0.7 \ + --perframe_ae +done diff --git a/scripts/run_world_model_interaction.sh b/scripts/run_world_model_interaction.sh new file mode 100644 index 0000000..e8fc586 --- /dev/null +++ b/scripts/run_world_model_interaction.sh @@ -0,0 +1,42 @@ +model_name=testing +ckpt=/path/to/model/checkpoint +config=configs/inference/world_model_interaction.yaml +seed=123 +res_dir="/path/to/result/directory" + +datasets=( + "unitree_z1_stackbox" + "unitree_z1_dual_arm_stackbox" + "unitree_z1_dual_arm_stackbox_v2" + "unitree_z1_dual_arm_cleanup_pencils" + "unitree_g1_pack_camera" +) + +n_iters=(12 7 11 8 11) +fses=(4 4 4 4 6) + +for i in "${!datasets[@]}"; do + dataset=${datasets[$i]} + n_iter=${n_iters[$i]} + fs=${fses[$i]} + + CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ + --seed ${seed} \ + --ckpt_path $ckpt \ + --config $config \ + --savedir "${res_dir}/${model_name}/${dataset}" \ + --bs 1 --height 320 --width 512 \ + --unconditional_guidance_scale 1.0 \ + --ddim_steps 50 \ + --ddim_eta 1.0 \ + --prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \ + --dataset ${dataset} \ + --video_length 
16 \
+    --frame_stride ${fs} \
+    --n_action_steps 16 \
+    --exe_steps 16 \
+    --n_iter ${n_iter} \
+    --timestep_spacing 'uniform_trailing' \
+    --guidance_rescale 0.7 \
+    --perframe_ae
+done
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000..d37ee79
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,32 @@
+# NCCL configuration
+# export NCCL_DEBUG=debug
+# export NCCL_IB_DISABLE=0
+# export NCCL_IB_GID_INDEX=3
+# export NCCL_NET_GDR_LEVEL=3
+# export CUDA_LAUNCH_BLOCKING=1
+
+# export NCCL_TOPO_FILE=/tmp/topo.txt
+# export MASTER_ADDR="master.ip."
+# export MASTER_PORT=12366
+
+
+# args
+name="experiment_name"
+config_file=configs/train/config.yaml
+
+# save root dir for logs, checkpoints, tensorboard record, etc.
+save_root="/path/to/savedir"
+
+mkdir -p $save_root/$name
+
+## run
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
+--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
+./scripts/trainer.py \
+--base $config_file \
+--train \
+--name $name \
+--logdir $save_root \
+--devices 8 \
+--total_gpus=8 \
+lightning.trainer.num_nodes=1
diff --git a/scripts/trainer.py b/scripts/trainer.py
new file mode 100644
index 0000000..87c6820
--- /dev/null
+++ b/scripts/trainer.py
@@ -0,0 +1,214 @@
+import argparse, os, datetime
+import pytorch_lightning as pl
+import torch
+
+from omegaconf import OmegaConf
+from transformers import logging as transf_logging
+from pytorch_lightning import seed_everything
+from pytorch_lightning.trainer import Trainer
+
+from unifolm_wma.utils.utils import instantiate_from_config
+from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
+from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
+
+
+def get_parser(**parser_kwargs):
+    parser = argparse.ArgumentParser(**parser_kwargs)
+    parser.add_argument("--seed",
+                        "-s",
+                        type=int,
+                        default=20250912,
+                        help="seed for seed_everything")
+    parser.add_argument("--name",
+                        "-n",
+                        type=str,
+                        default="",
+                        help="experiment name, used as the saving folder")
+    parser.add_argument(
+        "--base",
+        "-b",
+        nargs="*",
+        metavar="base_config.yaml",
+        help="paths to base configs. Loaded from left-to-right.",
+        default=list())
+    parser.add_argument("--train",
+                        "-t",
+                        action='store_true',
+                        default=False,
+                        help='train')
+    parser.add_argument("--val",
+                        "-v",
+                        action='store_true',
+                        default=False,
+                        help='val')
+    parser.add_argument("--test",
+                        action='store_true',
+                        default=False,
+                        help='test')
+    parser.add_argument("--logdir",
+                        "-l",
+                        type=str,
+                        default="logs",
+                        help="directory for logging data")
+    parser.add_argument("--auto_resume",
+                        action='store_true',
+                        default=False,
+                        help="resume from full-info checkpoint")
+    parser.add_argument("--auto_resume_weight_only",
+                        action='store_true',
+                        default=False,
+                        help="resume from weight-only checkpoint")
+    parser.add_argument("--debug",
+                        "-d",
+                        action='store_true',
+                        default=False,
+                        help="enable post-mortem debugging")
+    return parser
+
+
+def get_nondefault_trainer_args(args):
+    parser = argparse.ArgumentParser()
+    parser = Trainer.add_argparse_args(parser)
+    default_trainer_args = parser.parse_args([])
+    return sorted(k for k in vars(default_trainer_args)
+                  if getattr(args, k) != getattr(default_trainer_args, k))
+
+
+if __name__ == "__main__":
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    local_rank = int(os.environ.get('LOCAL_RANK'))
+    global_rank = int(os.environ.get('RANK'))
+    num_rank = int(os.environ.get('WORLD_SIZE'))
+
+    parser = get_parser()
+    # Extend the existing argparse with the default Trainer attributes
+    parser = Trainer.add_argparse_args(parser)
+    args, unknown = parser.parse_known_args()
+    transf_logging.set_verbosity_error()
+    seed_everything(args.seed)
+
+    configs = [OmegaConf.load(cfg) for cfg in args.base]
+    cli = OmegaConf.from_dotlist(unknown)
+    config = OmegaConf.merge(*configs, cli)
+    lightning_config = config.pop("lightning", OmegaConf.create())
+    trainer_config = lightning_config.get("trainer", OmegaConf.create())
+
+    # Setup workspace directories
+    workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
+                                                       config,
+                                                       lightning_config,
+                                                       global_rank)
+    logger = set_logger(
+        logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
+    logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
+    logger.info("***** Configuring Model *****")
+    config.model.params.logdir = workdir
+    model = instantiate_from_config(config.model)
+    # Load checkpoints
+    model = load_checkpoints(model, config.model)
+
+    # Call register_schedule again to make ZTSNR work
+    if model.rescale_betas_zero_snr:
+        model.register_schedule(given_betas=model.given_betas,
+                                beta_schedule=model.beta_schedule,
+                                timesteps=model.timesteps,
+                                linear_start=model.linear_start,
+                                linear_end=model.linear_end,
+                                cosine_s=model.cosine_s)
+
+    # Update trainer config
+    for k in get_nondefault_trainer_args(args):
+        trainer_config[k] = getattr(args, k)
+
+    num_nodes = trainer_config.num_nodes
+    ngpu_per_node = trainer_config.devices
+    logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
+
+    # Setup learning rate
+    base_lr = config.model.base_learning_rate
+    bs = config.data.params.batch_size
+    if getattr(config.model, 'scale_lr', True):
+        model.learning_rate = num_rank * bs * base_lr
+    else:
+        model.learning_rate = base_lr
+
+    logger.info("***** Configuring Data *****")
+    data = instantiate_from_config(config.data)
+    data.setup()
+    for k in data.train_datasets:
+        logger.info(
+            f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
+        )
+    if hasattr(data, 'val_datasets'):
+        for k in data.val_datasets:
+            logger.info(
+                f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
+            )
+
+    for item in unknown:
+        if item.startswith('--total_gpus'):
+            num_gpus = int(item.split('=')[-1])
+            break
+    model.datasets_len = len(data)
+
+    logger.info("***** Configuring Trainer *****")
+    if "accelerator" not in trainer_config:
+        trainer_config["accelerator"] = "gpu"
+
+    # Setup trainer args: pl-logger and callbacks
+    trainer_kwargs = dict()
+    trainer_kwargs["num_sanity_val_steps"] = 0
+    logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
+    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+    # Setup callbacks
+    callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
+                                          ckptdir, logger)
+    trainer_kwargs["callbacks"] = [
+        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
+    ]
+    strategy_cfg = get_trainer_strategy(lightning_config)
+    trainer_kwargs["strategy"] = strategy_cfg if isinstance(
+        strategy_cfg, str) else instantiate_from_config(strategy_cfg)
+    trainer_kwargs['precision'] = lightning_config.get('precision', 32)
+    trainer_kwargs["sync_batchnorm"] = False
+
+    # Trainer config: others
+    trainer_args = argparse.Namespace(**trainer_config)
+    trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
+
+    # Allow checkpointing via USR1
+    def melk(*args, **kwargs):
+        if trainer.global_rank == 0:
+            print("Summoning checkpoint.")
+            ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
+            trainer.save_checkpoint(ckpt_path)
+
+    def divein(*args, **kwargs):
+        if trainer.global_rank == 0:
+            import pudb
+            pudb.set_trace()
+
+    import signal
+    signal.signal(signal.SIGUSR1, melk)
+    signal.signal(signal.SIGUSR2, divein)
+
+    # List the key model sizes
+    total_params = get_num_parameters(model)
+
+    logger.info("***** Running the Loop *****")
+    if args.train:
+        try:
+            if "strategy" in lightning_config and lightning_config[
+                    'strategy'].startswith('deepspeed'):
+                logger.info("")
+                if trainer_kwargs['precision'] == 16:
+                    with torch.cuda.amp.autocast():
+                        trainer.fit(model, data)
+                else:
+                    trainer.fit(model, data)
+            else:
+                logger.info("")
+                trainer.fit(model, data)
+        except Exception:
+            raise
diff --git a/src/unifolm_wma/__init__.py b/src/unifolm_wma/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/data/base.py b/src/unifolm_wma/data/base.py
new file mode 100644
index 0000000..b10b13a
--- /dev/null
+++ b/src/unifolm_wma/data/base.py
@@ -0,0 +1,26 @@
+from abc import abstractmethod
+from torch.utils.data import IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+    '''
+    Define an interface to make the IterableDatasets for text2img data chainable
+    '''
+
+    def __init__(self, num_records=0, valid_ids=None, size=256):
+        super().__init__()
+        self.num_records = num_records
+        self.valid_ids = valid_ids
+        self.sample_ids = valid_ids
+        self.size = size
+
+        print(
+            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
+        )
+
+    def __len__(self):
+        return self.num_records
+
+    @abstractmethod
+    def __iter__(self):
+        pass
diff --git a/src/unifolm_wma/data/normolize.py b/src/unifolm_wma/data/normolize.py
new file mode 100644
index 0000000..98f8d7e
--- /dev/null
+++ b/src/unifolm_wma/data/normolize.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import Tensor, nn
+from typing import Dict, List
+
+
+def create_stats_buffers(
+    shapes: Dict[str, List[int]],
+    modes: Dict[str, str],
+    stats: Dict[str, Dict[str, Tensor]] = None,
+) -> Dict[str, Dict[str, nn.ParameterDict]]:
+    """
+    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
+    statistics.
+
+    Args: (see Normalize and Unnormalize)
+
+    Returns:
+        Dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
+            `nn.Parameter`s set to `requires_grad=False`, so they are not updated during backpropagation.
+    """
+    stats_buffers = {}
+
+    for key, mode in modes.items():
+        assert mode in ["mean_std", "min_max"]
+
+        shape = tuple(shapes[key])
+
+        if "image" in key:
+            # sanity checks
+            assert len(
+                shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
+            c, h, w = shape
+            assert c < h and c < w, f"{key} is not channel first ({shape=})"
+            # override image shape to be invariant to height and width
+            shape = (c, 1, 1)
+
+        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
+        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
+        # we assert they are not infinity anymore.
+
+        if "action" in key:
+            target_key = "action"
+        elif "state" in key:
+            target_key = 'observation.state'
+        else:
+            target_key = key
+
+        buffer = {}
+        if mode == "mean_std":
+            mean = torch.ones(shape, dtype=torch.float32) * torch.inf
+            std = torch.ones(shape, dtype=torch.float32) * torch.inf
+            buffer = nn.ParameterDict({
+                "mean":
+                nn.Parameter(mean, requires_grad=False),
+                "std":
+                nn.Parameter(std, requires_grad=False),
+            })
+        elif mode == "min_max":
+            min = torch.ones(shape, dtype=torch.float32) * torch.inf
+            max = torch.ones(shape, dtype=torch.float32) * torch.inf
+            buffer = nn.ParameterDict({
+                "min":
+                nn.Parameter(min, requires_grad=False),
+                "max":
+                nn.Parameter(max, requires_grad=False),
+            })
+
+        if stats is not None:
+            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+            # tensors anywhere (for example, when we use the same stats for normalization and
+            # unnormalization). See the logic here
+            if mode == "mean_std":
+                buffer["mean"].data = stats[target_key]["mean"].clone()
+                buffer["std"].data = stats[target_key]["std"].clone()
+            elif mode == "min_max":
+                buffer["min"].data = stats[target_key]["min"].clone()
+                buffer["max"].data = stats[target_key]["max"].clone()
+
+        stats_buffers[key] = buffer
+    return stats_buffers
+
+
+def _no_stats_error_str(name: str) -> str:
+    return (
+        f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
+        "pretrained model.")
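+
+
+# Illustrative configuration for the buffers above (hypothetical dimensions):
+#   shapes = {"observation.state": [14], "action": [14]}
+#   modes = {"observation.state": "mean_std", "action": "min_max"}
+#   stats = {"observation.state": {"mean": m, "std": s},
+#            "action": {"min": lo, "max": hi}}
+# create_stats_buffers(shapes, modes, stats) then returns one nn.ParameterDict
+# per modality, each holding requires_grad=False parameters.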
"observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image") + and values are Dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for + training the model for the first time, these statistics will overwrite the default buffers. If + not provided, as expected for finetuning or evaluation, the default buffers should to be + overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the + dataset is not needed to get the stats, since they are already in the policy state_Dict. + """ + super().__init__() + self.shapes = shapes + self.modes = modes + self.stats = stats + stats_buffers = create_stats_buffers(shapes, modes, stats) + for key, buffer in stats_buffers.items(): + setattr(self, "buffer_" + key.replace(".", "_"), buffer) + + @torch.no_grad() + def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: + for key, mode in self.modes.items(): + if key not in batch: + continue + + buffer = getattr(self, "buffer_" + key.replace(".", "_")) + + if mode == "mean_std": + mean = buffer["mean"] + std = buffer["std"] + + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + elif mode == "min_max": + min = buffer["min"] + max = buffer["max"] + + assert not torch.isinf(min).any(), _no_stats_error_str("min") + assert not torch.isinf(max).any(), _no_stats_error_str("max") + # normalize to [0,1] + batch[key] = (batch[key] - min) / (max - min + 1e-8) + # normalize to [-1, 1] + batch[key] = batch[key] * 2 - 1 + else: + raise ValueError(mode) + return batch + + +class Unnormalize(nn.Module): + """ + Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their + original range used by the environment. + """ + + def __init__( + self, + shapes: Dict[str, List[int]], + modes: Dict[str, str], + stats: Dict[str, Dict[str, Tensor]] = None, + ): + """ + Args: + shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values + are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing + mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape + is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. + modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values + are their normalization modes among: + - "mean_std": subtract the mean and divide by standard deviation. + - "min_max": map to [-1, 1] range. + stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image") + and values are Dictionaries of statistic types and their values (e.g. + `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). 
If provided, as expected when
+            training the model for the first time, these statistics overwrite the default buffers. If
+            not provided, as expected for finetuning or evaluation, the default buffers should be
+            overwritten by a call to `policy.load_state_dict(state_dict)`. That way, the stats are taken
+            from the policy state dict and the dataset does not need to be initialized to obtain them.
+        """
+        super().__init__()
+        self.shapes = shapes
+        self.modes = modes
+        self.stats = stats
+        stats_buffers = create_stats_buffers(shapes, modes, stats)
+        for key, buffer in stats_buffers.items():
+            setattr(self, "buffer_" + key.replace(".", "_"), buffer)
+
+    @torch.no_grad()
+    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        for key, mode in self.modes.items():
+            if key not in batch:
+                continue
+
+            buffer = getattr(self, "buffer_" + key.replace(".", "_"))
+
+            if mode == "mean_std":
+                mean = buffer["mean"]
+                std = buffer["std"]
+                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+                assert not torch.isinf(std).any(), _no_stats_error_str("std")
+                batch[key] = batch[key] * std + mean
+            elif mode == "min_max":
+                min = buffer["min"]
+                max = buffer["max"]
+                assert not torch.isinf(min).any(), _no_stats_error_str("min")
+                assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Map from [-1, 1] back to [0, 1], then to the original range.
+                batch[key] = (batch[key] + 1) / 2
+                batch[key] = batch[key] * (max - min) + min
+            else:
+                raise ValueError(mode)
+        return batch
diff --git a/src/unifolm_wma/data/utils.py b/src/unifolm_wma/data/utils.py
new file mode 100644
index 0000000..55675ae
--- /dev/null
+++ b/src/unifolm_wma/data/utils.py
@@ -0,0 +1,60 @@
+import torch
+
+from huggingface_hub import hf_hub_download, snapshot_download
+from typing import Dict, List, Union
+from pathlib import Path
+from safetensors.torch import load_file
+
+
+def unflatten_dict(d, sep="/"):
+    """Turn a flat dict with `sep`-joined keys into a nested dict."""
+    outdict = {}
+    for key, value in d.items():
+        parts = key.split(sep)
+        node = outdict
+        for part in parts[:-1]:
+            if part not in node:
+                node[part] = {}
+            node = node[part]
+        node[parts[-1]] = value
+    return outdict
+
+
+def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
+    """episode_data_index contains the range of indices for each episode.
+
+    Example:
+        ```python
+        from_id = episode_data_index["from"][episode_id].item()
+        to_id = episode_data_index["to"][episode_id].item()
+        episode_frames = [dataset[i] for i in range(from_id, to_id)]
+        ```
+    """
+    if root is not None:
+        path = Path(
+            root) / repo_id / "meta_data" / "episode_data_index.safetensors"
+    else:
+        path = hf_hub_download(repo_id,
+                               "meta_data/episode_data_index.safetensors",
+                               repo_type="dataset",
+                               revision=version)
+
+    return load_file(path)
+
+
+def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
+    """stats contains per-modality statistics computed over the full dataset, such as max, min, mean and std.
+
+    Example:
+        ```python
+        normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
+        ```
+    """
+    if root is not None:
+        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
+    else:
+        path = hf_hub_download(repo_id,
+                               "meta_data/stats.safetensors",
+                               repo_type="dataset",
+                               revision=version)
+
+    stats = load_file(path)
+    return unflatten_dict(stats)
diff --git a/src/unifolm_wma/data/wma_data.py b/src/unifolm_wma/data/wma_data.py
new file mode 100644
index 0000000..bf64b32
--- /dev/null
+++ b/src/unifolm_wma/data/wma_data.py
@@ -0,0 +1,408 @@
+import torch
+import os
+import random
+import pandas as pd
+import h5py
+
+from decord import
VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms +from pathlib import Path + +from unifolm_wma.data.utils import load_stats +from unifolm_wma.data.normolize import Normalize, Unnormalize + + +class WMAData(Dataset): + """ + Assuming the following dataset structure: + dataset_dir/ + ├── videos + │ ├──dataset_name + │ │ ├──camera_view_dir + │ │ ├── 0.mp4 + │ │ ├── 1.mp4 + │ │ └── ... + │ └── ... + ├── transitions + │ ├── dataset_name + │ ├── meta_data + │ ├── 0.h5 + │ ├── 1.h5 + │ └── ... + └── dataset_name.csv + """ + + def __init__( + self, + meta_path, + data_dir, + subsample=None, + video_length=16, + resolution=[256, 512], + frame_stride=1, + frame_stride_min=1, + spatial_transform=None, + crop_resolution=None, + fps_max=None, + load_raw_resolution=False, + fixed_fps=None, + random_fs=False, + cond_robot_label_prob=0.0, + transition_dir=None, + dataset_name=None, + normalization_mode='min_max', + individual_normalization=False, + n_obs_steps=1, + max_action_dim=7, + max_state_dim=7, + ): + self.meta_path = meta_path + self.data_dir = data_dir + self.subsample = subsample + self.video_length = video_length + self.resolution = [resolution, resolution] if isinstance( + resolution, int) else resolution + self.fps_max = fps_max + self.frame_stride = frame_stride + self.frame_stride_min = frame_stride_min + self.fixed_fps = fixed_fps + self.load_raw_resolution = load_raw_resolution + self.random_fs = random_fs + self.cond_robot_label_prob = cond_robot_label_prob + self.transition_dir = transition_dir + self.dataset_name = dataset_name + self.max_action_dim = max_action_dim + self.max_state_dim = max_state_dim + + self._load_metadata() + if spatial_transform is not None: + if spatial_transform == "random_crop": + self.spatial_transform = transforms.RandomCrop(crop_resolution) + elif spatial_transform == "center_crop": + self.spatial_transform = transforms.Compose([ + transforms.CenterCrop(resolution), + ]) + elif spatial_transform == "resize_center_crop": + self.spatial_transform = transforms.Compose([ + transforms.Resize(min(self.resolution)), + transforms.CenterCrop(self.resolution), + ]) + elif spatial_transform == "resize": + self.spatial_transform = transforms.Resize(self.resolution) + else: + raise NotImplementedError + else: + self.spatial_transform = None + + self.normalization_mode = normalization_mode + self.individual_normalization = individual_normalization + self.n_obs_steps = n_obs_steps + self._load_stats() + if individual_normalization: + self._init_normalizers() + + def _load_metadata(self): + metadata = pd.read_csv(self.meta_path, dtype=str) + if self.subsample is not None: + metadata = metadata.sample(self.subsample, random_state=0) + + self.metadata = metadata + # drop the rows contain NaN values + self.metadata.dropna(inplace=True) + print( + f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded." 
+ ) + + def _load_stats(self): + self.stats = load_stats(self.dataset_name, None, self.transition_dir) + print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.") + + def _init_normalizers(self): + shape_dict = { + 'pre_action': [self.stats['action']['max'].shape[-1]], + 'action': [self.stats['action']['max'].shape[-1]], + 'observation.state': + [self.stats['observation.state']['max'].shape[-1]], + 'next.state': [self.stats['observation.state']['max'].shape[-1]] + } + normalization_mode_dict = { + 'pre_action': self.normalization_mode, + 'action': self.normalization_mode, + 'observation.state': self.normalization_mode, + 'next.state': self.normalization_mode + } + self.normalizer = Normalize(shape_dict, normalization_mode_dict, + self.stats) + self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict, + self.stats) + print( + f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.") + + def _get_video_path(self, sample): + rel_video_fp = os.path.join(sample['data_dir'], + str(sample['videoid']) + '.mp4') + full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) + return full_video_fp + + def _get_transition_path(self, sample): + data_dir = Path(sample['data_dir']) + if self.dataset_name == data_dir.name: + rel_transition_fp = os.path.join(str(data_dir), + str(sample['videoid']) + '.h5') + else: + rel_transition_fp = os.path.join(str(data_dir.parent), + str(sample['videoid']) + '.h5') + full_transition_fp = os.path.join(self.data_dir, 'transitions', + rel_transition_fp) + return full_transition_fp + + def get_uni_vec(self, action_state_dict, action_type, state_type): + if 'pre_action' in action_state_dict: + action_state_dict['pre_action'], _ = self._map_to_uni_action( + action_state_dict['pre_action'], action_type) + if 'action' in action_state_dict: + action_state_dict['action'], action_state_dict[ + 'action_mask'] = self._map_to_uni_action( + action_state_dict['action'], action_type) + if 'observation.state' in action_state_dict: + action_state_dict['observation.state'], _ = self._map_to_uni_state( + action_state_dict['observation.state'], state_type) + if 'next.state' in action_state_dict: + action_state_dict['next.state'], action_state_dict[ + 'state_mask'] = self._map_to_uni_state( + action_state_dict['next.state'], state_type) + return action_state_dict + + def _map_to_uni_action(self, action, action_type): + action_dim = action.shape[-1] + uni_action = torch.nn.functional.pad( + action, (0, self.max_action_dim - action_dim), + mode='constant', + value=0) + uni_action_mask = torch.zeros_like(uni_action) + uni_action_mask[:, :action_dim] = 1 + return uni_action, uni_action_mask + + def _map_to_uni_state(self, state, state_type): + state_dim = state.shape[-1] + uni_state = torch.nn.functional.pad( + state, (0, self.max_state_dim - state_dim), + mode='constant', + value=0) + uni_state_mask = torch.zeros_like(uni_state) + uni_state_mask[:, :state_dim] = 1 + return uni_state, uni_state_mask + + def __getitem__(self, index): + + if self.random_fs: + frame_stride = random.randint(self.frame_stride_min, + self.frame_stride) + else: + frame_stride = self.frame_stride + + # Get frames until success + while True: + index = index % len(self.metadata) + sample = self.metadata.iloc[index] + video_path = self._get_video_path(sample) + + instruction = sample['instruction'] + if self.cond_robot_label_prob > 0.0 and random.random( + ) < self.cond_robot_label_prob: + if sample['embodiment'] != 'x': + instruction = sample['embodiment'] + ' [SEP] ' + sample[ + 
'instruction'] + try: + if self.load_raw_resolution: + video_reader = VideoReader(video_path, ctx=cpu(0)) + else: + video_reader = VideoReader(video_path, + ctx=cpu(0), + width=530, + height=300) + if len(video_reader) < self.video_length: + print( + f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})" + ) + index += 1 + continue + else: + pass + except: + index += 1 + print(f">>> Error: load video failed! path = {video_path}") + continue + + fps_ori = video_reader.get_avg_fps() + if self.fixed_fps is not None: + frame_stride = int(frame_stride * + (1.0 * fps_ori / self.fixed_fps)) + + # To avoid extreme cases when fixed_fps is used + frame_stride = max(frame_stride, 1) + + # Get valid range (adapting case by case) + required_frame_num = frame_stride * (self.video_length - 1) + 1 + frame_num = len(video_reader) + if frame_num < required_frame_num: + # Drop extra samples if fixed fps is required + if self.fixed_fps is not None and frame_num < required_frame_num * 0.5: + index += 1 + continue + else: + frame_stride = frame_num // self.video_length + required_frame_num = frame_stride * (self.video_length - + 1) + 1 + + # Select a random clip + random_range = frame_num - required_frame_num + start_idx = random.randint( + 0, random_range - + frame_stride) if random_range - frame_stride > 0 else 0 + + # Calculate frame indices + frame_indices = [ + start_idx + frame_stride * i for i in range(self.video_length) + ] + try: + next_frame_indices = [ + idx + frame_stride for idx in frame_indices + ] + frames = video_reader.get_batch(next_frame_indices) + break + except: + print( + f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]" + ) + index += 1 + continue + + # Load transition data + transition_path = self._get_transition_path(sample) + with h5py.File(transition_path, 'r') as h5f: + transition_dict = {} + for key in h5f.keys(): + transition_dict[key] = torch.tensor(h5f[key][()]) + for key in h5f.attrs.keys(): + transition_dict[key] = h5f.attrs[key] + + # Load observable states + if start_idx < self.n_obs_steps - 1: + state_indices = list(range(0, start_idx + 1)) + states = transition_dict['observation.state'][state_indices, :] + num_padding = self.n_obs_steps - 1 - start_idx + first_slice = states[0:1, :] # (t, d) + padding = first_slice.repeat(num_padding, 1) + states = torch.cat((padding, states), dim=0) + else: + state_indices = list( + range(start_idx - self.n_obs_steps + 1, start_idx + 1)) + states = transition_dict['observation.state'][state_indices, :] + assert states.shape[ + 0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.' 
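+
+        # `states` now holds the `n_obs_steps` most recent states up to and
+        # including `start_idx`, left-padded by repeating the earliest available
+        # state when the clip starts near the beginning of the episode.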
+ + # Load observable actions + if start_idx < self.n_obs_steps: + pre_action_indices = list(range(0, start_idx)) + pre_actions = transition_dict['action'][pre_action_indices, :] + num_padding = self.n_obs_steps - start_idx + first_slice = torch.zeros_like(transition_dict['action'][:1, :]) + padding = first_slice.repeat(num_padding, 1) + pre_actions = torch.cat((padding, pre_actions), dim=0) + else: + pre_action_indices = list( + range(start_idx - self.n_obs_steps, start_idx)) + pre_actions = transition_dict['action'][pre_action_indices, :] + assert pre_actions.shape[ + 0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation" + + # Load future actions + actions = transition_dict['action'][frame_indices, :] + # Load future states + next_state_indices = [idx + frame_stride for idx in frame_indices] + next_states = transition_dict['observation.state'][ + next_state_indices, :] + frames_action_state_dict = { + 'pre_action': pre_actions, + 'action': actions, + 'observation.state': states, + 'next.state': next_states + } + if self.individual_normalization: + frames_action_state_dict = self.normalizer( + frames_action_state_dict) + + # Update action and states to unified vector + frames_action_state_dict = self.get_uni_vec( + frames_action_state_dict, + transition_dict['action_type'], + transition_dict['state_type'], + ) + + # Load observable images + if start_idx < self.n_obs_steps - 1: + action_net_frame_indices = list(range(0, start_idx + 1)) + action_net_frames = video_reader.get_batch( + action_net_frame_indices) + action_net_frames = torch.tensor( + action_net_frames.asnumpy()).permute(0, 3, 1, 2).float() + first_slice = action_net_frames[0:1, :] + num_padding = self.n_obs_steps - 1 - start_idx + padding = first_slice.repeat(num_padding, 1, 1, 1) + action_net_frames = torch.cat((padding, action_net_frames), dim=0) + assert ( + action_net_frames.shape[0] == self.n_obs_steps + ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}' + action_net_frames = action_net_frames.permute(1, 0, 2, 3) + else: + action_net_frame_indices = list( + range(start_idx - self.n_obs_steps + 1, start_idx + 1)) + action_net_frames = video_reader.get_batch( + action_net_frame_indices) + assert ( + action_net_frames.shape[0] == self.n_obs_steps + ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}' + action_net_frames = torch.tensor( + action_net_frames.asnumpy()).permute(3, 0, 1, 2).float() + + assert (frames.shape[0] == self.video_length + ), f'{len(frames)}, self.video_length={self.video_length}' + frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() + + if self.spatial_transform is not None: + frames = self.spatial_transform(frames) + action_net_frames = self.spatial_transform(action_net_frames) + + if self.resolution is not None: + assert (frames.shape[2], frames.shape[3]) == ( + self.resolution[0], self.resolution[1] + ), f'frames={frames.shape}, self.resolution={self.resolution}' + assert ( + action_net_frames.shape[2], action_net_frames.shape[3] + ) == ( + self.resolution[0], self.resolution[1] + ), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}' + + # Normalize frames tensors to [-1,1] + frames = (frames / 255 - 0.5) * 2 + action_net_frames = (action_net_frames / 255 - 0.5) * 2 + fps_clip = fps_ori // frame_stride + if self.fps_max is not None and fps_clip > self.fps_max: + fps_clip = self.fps_max + + data = { + 'video': frames, + 'instruction': instruction, + 'path': video_path, + 'fps': fps_clip, + 
'frame_stride': frame_stride, + 'observation.image': action_net_frames, + } + data.update(frames_action_state_dict) + + return data + + def __len__(self): + return len(self.metadata) diff --git a/src/unifolm_wma/models/__init__.py b/src/unifolm_wma/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/models/autoencoder.py b/src/unifolm_wma/models/autoencoder.py new file mode 100644 index 0000000..2a3b521 --- /dev/null +++ b/src/unifolm_wma/models/autoencoder.py @@ -0,0 +1,267 @@ +import os +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from einops import rearrange +from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder +from unifolm_wma.utils.distributions import DiagonalGaussianDistribution +from unifolm_wma.utils.utils import instantiate_from_config + + +class AutoencoderKL(pl.LightningModule): + + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + test=False, + logdir=None, + input_dim=4, + test_args=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + self.input_dim = input_dim + self.test = test + self.test_args = test_args + self.logdir = logdir + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + if self.test: + self.init_test() + + def init_test(self, ): + self.test = True + save_dir = os.path.join(self.logdir, "test") + if 'ckpt' in self.test_args: + ckpt_name = os.path.basename(self.test_args.ckpt).split( + '.ckpt')[0] + f'_epoch{self._cur_epoch}' + self.root = os.path.join(save_dir, ckpt_name) + else: + self.root = save_dir + if 'test_subdir' in self.test_args: + self.root = os.path.join(save_dir, self.test_args.test_subdir) + + self.root_zs = os.path.join(self.root, "zs") + self.root_dec = os.path.join(self.root, "reconstructions") + self.root_inputs = os.path.join(self.root, "inputs") + os.makedirs(self.root, exist_ok=True) + + if self.test_args.save_z: + os.makedirs(self.root_zs, exist_ok=True) + if self.test_args.save_reconstruction: + os.makedirs(self.root_dec, exist_ok=True) + if self.test_args.save_input: + os.makedirs(self.root_inputs, exist_ok=True) + assert (self.test_args is not None) + self.test_maximum = getattr(self.test_args, 'test_maximum', None) + self.count = 0 + self.eval_metrics = {} + self.decodes = [] + self.save_decode_samples = 2048 + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + try: + self._cur_epoch = sd['epoch'] + sd = sd["state_dict"] + except: + self._cur_epoch = 'null' + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x, **kwargs): + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = 
DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.dim() == 5 and self.input_dim == 4: + b, c, t, h, w = x.shape + self.b = b + self.t = t + x = rearrange(x, 'b c t h w -> (b t) c h w') + + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + self.log("aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict(log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") + + self.log("discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict(log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val") + + discloss, log_dict_disc = self.loss(inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, + betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
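+        # The colorized map is rescaled to [-1, 1] for logging/visualization.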
+ return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py new file mode 100644 index 0000000..fbf2042 --- /dev/null +++ b/src/unifolm_wma/models/ddpms.py @@ -0,0 +1,2524 @@ +""" +wild mixture of +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import random +import torch +import torch.nn as nn +import copy +import numpy as np +import pytorch_lightning as pl +import torch.nn.functional as F +import logging + +mainlogger = logging.getLogger('mainlogger') + +from functools import partial +from contextlib import contextmanager +from tqdm import tqdm +from einops import rearrange, repeat, reduce +from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR +from torchvision.utils import make_grid +from pytorch_lightning.utilities import rank_zero_only +from omegaconf import OmegaConf +from typing import Optional, Sequence, Any, Tuple, Union, List, Dict +from collections.abc import Mapping, Iterable, Callable +from torch import Tensor + +from unifolm_wma.utils.utils import instantiate_from_config +from unifolm_wma.utils.ema import LitEma +from unifolm_wma.utils.distributions import DiagonalGaussianDistribution +from unifolm_wma.utils.diffusion import make_beta_schedule, rescale_zero_terminal_snr +from unifolm_wma.utils.basics import disabled_train +from unifolm_wma.utils.common import (extract_into_tensor, noise_like, exists, + default) + +from unifolm_wma.models.samplers.ddim import DDIMSampler +from unifolm_wma.models.diffusion_head.common.lr_scheduler import get_scheduler, SelectiveLRScheduler +from unifolm_wma.models.diffusion_head.ema_model import EMAModel +from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb +from unifolm_wma.modules.encoders.condition import MLPProjector +from unifolm_wma.data.normolize import Normalize, Unnormalize + +__conditioning_keys__ = { + 'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y' +} + + +class DDPM(pl.LightningModule): + """ + Denoising Diffusion Probabilistic Model (DDPM) LightningModule. 
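+
+    Implements the forward noising process q(x_t | x_0) and a learned reverse
+    process p(x_{t-1} | x_t), following Ho et al., "Denoising Diffusion
+    Probabilistic Models" (2020).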
+ """ + + def __init__( + self, + wma_config: OmegaConf, + timesteps: int = 1000, + beta_schedule: str = "linear", + loss_type: str = "l2", + ckpt_path: Optional[str] = None, + ignore_keys: Optional[Sequence[str]] = [], + load_only_unet: bool = False, + monitor: str = None, + use_ema: bool = True, + first_stage_key: str = "image", + image_size: int = 256, + channels: int = 3, + log_every_t: int = 100, + clip_denoised: bool = True, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3, + given_betas: Optional[np.ndarray] = None, + original_elbo_weight: float = 0.0, + v_posterior: float = 0.0, + l_simple_weight: float = 1.0, + conditioning_key: Optional[str] = None, + parameterization: str = "eps", + scheduler_config: Optional[Mapping[str, Any]] = None, + use_positional_encodings: bool = False, + learn_logvar: bool = False, + logvar_init: float = 0.0, + rescale_betas_zero_snr: bool = False, + ): + """ + wma_config: Config object used to build the underlying model. + timesteps: Number of diffusion steps. + beta_schedule: Schedule type for betas (e.g., 'linear', 'cosine'). + loss_type: Loss type. + ckpt_path: Optional checkpoint path to restore weights. + ignore_keys: Keys to ignore when loading the checkpoint. + load_only_unet: If True, load the backbone into self.model only. + monitor: Metric key for monitoring. + use_ema: If True, maintain EMA weights. + first_stage_key: Key in batch dict for inputs. + image_size: Image size. + channels: Number of channels. + log_every_t: Interval of timesteps to log intermediates during sampling. + clip_denoised: Clamp x_0 predictions or not. + linear_start: Linear schedule start. + linear_end: Linear schedule end. + cosine_s: Cosine schedule s parameter. + given_betas: Externally provided betas; overrides schedule if set. + original_elbo_weight: Weight for VLB term. + v_posterior: Interpolation weight for posterior variance. + l_simple_weight: Weight for simple loss term. + conditioning_key: Conditioning mechanism key (if used by wrapper). + parameterization: One of {'eps','x0','v'}. + scheduler_config: Optional LR scheduler config. + use_positional_encodings: Whether to inject positional encodings. + learn_logvar: If True, learn per-timestep log-variance. + logvar_init: Initial value for log-variance. + rescale_betas_zero_snr: If True, apply zero-SNR rescaling to betas. 
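+
+        Example (illustrative sketch; `wma_config` is assumed to be an OmegaConf
+        node providing the backbone config, including `params.temporal_length`):
+            ```python
+            ddpm = DDPM(wma_config,
+                        timesteps=1000,
+                        beta_schedule="linear",
+                        parameterization="v")
+            ```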
+ """ + super().__init__() + assert parameterization in [ + "eps", "x0", "v" + ], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + mainlogger.info( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.channels = channels + self.temporal_length = wma_config.params.temporal_length + self.image_size = image_size + if isinstance(self.image_size, int): + self.image_size = [self.image_size, self.image_size] + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(wma_config, conditioning_key) + self.use_ema = use_ema + self.rescale_betas_zero_snr = rescale_betas_zero_snr + if self.use_ema: + self.model_ema = LitEma(self.model) + mainlogger.info( + f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, + ignore_keys=ignore_keys, + only_model=load_only_unet) + self.register_schedule(given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + + # For reschedule + self.given_betas = given_betas + self.beta_schedule = beta_schedule + self.timesteps = timesteps + self.cosine_s = cosine_s + self.loss_type = loss_type + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, + size=(self.num_timesteps, )) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + def register_schedule(self, + given_betas: Optional[np.ndarray] = None, + beta_schedule: str = "linear", + timesteps: int = 1000, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3) -> None: + """ + Create and register diffusion buffers (betas, alphas, posteriors, weights). + + Args: + given_betas: If provided, use these instead of building a schedule. + beta_schedule: Name of schedule to create if betas not given. + timesteps: Number of diffusion steps. + linear_start: Linear schedule start. + linear_end: Linear schedule end. + cosine_s: Cosine schedule parameter + """ + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) + if self.rescale_betas_zero_snr: + betas = rescale_zero_terminal_snr(betas) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[ + 0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(alphas_cumprod_prev)) + + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. 
- alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod))) + + if self.parameterization != 'v': + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod - 1))) + else: + self.register_buffer('sqrt_recip_alphas_cumprod', + torch.zeros_like(to_torch(alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + torch.zeros_like(to_torch(alphas_cumprod))) + + posterior_variance = (1 - self.v_posterior) * betas * ( + 1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # Above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', + to_torch(posterior_variance)) + # Below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + 'posterior_log_variance_clipped', + to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + 'posterior_mean_coef1', + to_torch(betas * np.sqrt(alphas_cumprod_prev) / + (1. - alphas_cumprod))) + self.register_buffer( + 'posterior_mean_coef2', + to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / + (1. - alphas_cumprod))) + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * + to_torch(alphas) * + (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / ( + 2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like( + self.betas**2 / + (2 * self.posterior_variance * to_torch(alphas) * + (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context: Optional[str] = None) -> Iterable[None]: + """ + Context manager that temporarily swaps to EMA weights (if enabled). + + Args: + context: Optional label for logging. + + """ + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + mainlogger.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + mainlogger.info(f"{context}: Restored training weights") + + def init_from_ckpt(self, + path: str, + ignore_keys: Sequence[str] = tuple(), + only_model: bool = False) -> None: + """ + Load a checkpoint, optionally filtering keys or loading only the inner model. + + Args: + path: Path to checkpoint. + ignore_keys: State-dict keys (prefix match) to drop. + only_model: If True, load only into self.model. 
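+
+        Example (illustrative; the checkpoint path and ignored prefix are
+        hypothetical):
+            ```python
+            model.init_from_ckpt("path/to/model.ckpt",
+                                 ignore_keys=["model_ema."],
+                                 only_model=False)
+            ```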
+ """ + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + mainlogger.info( + "Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict( + sd, + strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + mainlogger.info( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + mainlogger.info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + mainlogger.info(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start: Tensor, + t: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compute q(x_t | x_0): mean, variance, and log-variance. + + Args: + x_start: the [N x C x ...] tensor of noiseless inputs.. + t: the number of diffusion steps (minus 1). Here, 0 means one step.. + + Returns: + (mean, variance, log_variance), each shaped like x_start. + """ + mean = ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t: Tensor, t: Tensor, + noise: Tensor) -> Tensor: + """ + Predict x_0 from x_t and noise. + """ + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * + x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * noise) + + def predict_start_from_z_and_v(self, x_t: Tensor, t: Tensor, + v: Tensor) -> Tensor: + """ + Predict x_0 from z and v (v-parameterization). + """ + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, + x_t.shape) * v) + + def predict_eps_from_z_and_v(self, x_t: Tensor, t: Tensor, + v: Tensor) -> Tensor: + """ + Predict epsilon from z and v (v-parameterization). + """ + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, + x_t.shape) * x_t) + + def q_posterior(self, x_start: Tensor, x_t: Tensor, + t: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Compute posterior q(x_{t-1} | x_t, x_0): mean and (log-)variance. + """ + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_variance = extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x: Tensor, t: Tensor, + clip_denoised: bool) -> Tuple[Tensor, Tensor, Tensor]: + """ + Predict mean and variance of p(x_{t-1} | x_t) using the model. + """ + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
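+
+        # Compute the mean and (log-)variance of q(x_{t-1} | x_t, x_0) around
+        # the (possibly clamped) x_0 estimate: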
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, + x: Tensor, + t: Tensor, + clip_denoised: bool = True, + repeat_noise: bool = False) -> Tensor: + """ + Draw a single reverse-diffusion sample step: x_{t-1} from x_t. + + Args: + x: Current noisy sample (B, C, ...). + t: Current timestep indices (B,). + clip_denoised: Clamp x_0 predictions or not. + repeat_noise: Reuse the same noise across the batch. + + Returns: + Next sample x_{t-1}. + """ + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # No noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * + model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop( + self, + shape: Sequence[int], + return_intermediates: bool = False + ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: + """ + Run the full reverse process starting from Gaussian noise. + + Args: + shape: Output tensor shape (B, C, ...). + return_intermediates: If True, also return intermediate frames. + + Returns: + Final sample, and optionally the intermediate list. + """ + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), + desc='Sampling t', + total=self.num_timesteps): + img = self.p_sample(img, + torch.full((b, ), + i, + device=device, + dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + batch_size: int = 16, + return_intermediates: bool = False + ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: + """ + Convenience wrapper to sample square images of configured size. + + Args: + batch_size: Number of samples. + return_intermediates: If True, also return intermediate frames. + + Returns: + Final sample (and optionally intermediates). + """ + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, + x_start: Tensor, + t: Tensor, + noise: Optional[Tensor] = None) -> Tensor: + """ + Forward noising step: sample x_t ~ q(x_t | x_0). + """ + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def get_v(self, x: Tensor, noise: Tensor, t: Tensor) -> Tensor: + """ + Compute v-target given x and epsilon. + """ + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, + x.shape) * x) + + def get_loss(self, + pred: Tensor, + target: Tensor, + mean: bool = True) -> Tensor: + """ + Compute training loss between prediction and target. + + Args: + pred: Model output. + target: Target tensor. + mean: If True, reduce to mean. + + Returns: + Loss tensor (scalar if reduced). 
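+
+        Example (illustrative shapes):
+            ```python
+            pred = torch.randn(4, 3, 64, 64)
+            target = torch.randn(4, 3, 64, 64)
+            per_elem = self.get_loss(pred, target, mean=False)  # (4, 3, 64, 64)
+            scalar = self.get_loss(pred, target, mean=True)     # 0-dim tensor
+            ```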
+        """
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target,
+                                                    pred,
+                                                    reduction='none')
+        else:
+            raise NotImplementedError(
+                f"unknown loss type '{self.loss_type}'")
+
+        return loss
+
+    def p_losses(self,
+                 x_start: Tensor,
+                 t: Tensor,
+                 noise: Optional[Tensor] = None
+                 ) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """
+        Compute the per-step training loss for a batch.
+
+        Args:
+            x_start: Clean inputs (B, C, ...).
+            t: Timesteps (B,).
+            noise: Optional pre-sampled epsilon.
+
+        Returns:
+            (loss, log_dict)
+        """
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError(
+                f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x: Tensor, *args: Any,
+                **kwargs: Any) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """
+        Lightning forward: sample random timesteps and compute losses.
+
+        Args:
+            x: Clean batch (B, C, ...).
+
+        Returns:
+            (loss, log_dict)
+        """
+        t = torch.randint(0,
+                          self.num_timesteps, (x.shape[0], ),
+                          device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch: Mapping[str, Tensor], k: str) -> Tensor:
+        """
+        Fetch and format the network input from the batch.
+
+        Args:
+            batch: Batch mapping.
+            k: Key for the tensor to use.
+
+        Returns:
+            (B, C, ...) float32 contiguous tensor.
+        """
+        x = batch[k]
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(
+            self, batch: Mapping[str,
+                                 Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """
+        Common train/val step computing loss and logs.
+        """
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch: Mapping[str, Tensor],
+                      batch_idx: int) -> Tensor:
+        """
+        PyTorch Lightning training step.
+        """
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict,
+                      prog_bar=True,
+                      logger=True,
+                      on_step=True,
+                      on_epoch=True)
+
+        self.log("global_step",
+                 self.global_step,
+                 prog_bar=True,
+                 logger=True,
+                 on_step=True,
+                 on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs',
+                     lr,
+                     prog_bar=True,
+                     logger=True,
+                     on_step=True,
+                     on_epoch=False)
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch: Mapping[str, Tensor],
+                        batch_idx: int) -> None:
+        """
+        PyTorch Lightning validation step, run with and without EMA weights;
+        metrics computed under EMA weights are logged with an `_ema` suffix.
+ """ + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = { + key + '_ema': loss_dict_ema[key] + for key in loss_dict_ema + } + self.log_dict(loss_dict_no_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + self.log_dict(loss_dict_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + + def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None: + """ + Update EMA after each train batch (if enabled). + """ + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples: List[Tensor]) -> Tensor: + """ + Arrange a list of (B, C, ...) tensors into a grid for logging. + + Args: + samples: List of tensors at different timesteps. + + Returns: + Grid image tensor suitable for visualization. + """ + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images( + self, + batch: Mapping[str, Tensor], + N: int = 8, + n_row: int = 2, + sample: bool = True, + return_keys: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> Dict[str, Tensor]: + """ + Create tensors for image logging: inputs, diffusion row, (optional) samples. + + Args: + batch: Batch mapping. + N: Number of examples to visualize. + n_row: Number of examples for diffusion-row visualization. + sample: If True, also run reverse diffusion to produce samples. + return_keys: If provided, filter the returned dict to these keys. + + Returns: + Dict of image tensors. + """ + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # Get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # Get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, + return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Build the optimizer (AdamW) over model parameters (+ logvar if learned). + """ + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """ + Main Class: Latent-diffusion model on top of DDPM (first/cond stages + guidance). 
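+
+    Runs diffusion in the latent space of a frozen first-stage autoencoder,
+    with a conditioning encoder (e.g. for language instructions) providing
+    guidance embeddings.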
+    """
+
+    def __init__(self,
+                 first_stage_config: OmegaConf,
+                 cond_stage_config: OmegaConf,
+                 num_timesteps_cond: int | None = None,
+                 cond_stage_key: str = "instruction",
+                 cond_stage_trainable: bool = False,
+                 cond_stage_forward: str | None = None,
+                 conditioning_key: str | None = None,
+                 uncond_prob: float = 0.2,
+                 uncond_type: str = "empty_seq",
+                 scale_factor: float = 1.0,
+                 scale_by_std: bool = False,
+                 encoder_type: str = "2d",
+                 only_model: bool = False,
+                 noise_strength: float = 0.0,
+                 use_dynamic_rescale: bool = False,
+                 base_scale: float = 0.7,
+                 turning_step: int = 400,
+                 interp_mode: bool = False,
+                 fps_condition_type: str = 'fs',
+                 perframe_ae: bool = False,
+                 logdir: str | None = None,
+                 rand_cond_frame: bool = False,
+                 en_and_decode_n_samples_a_time: int | None = None,
+                 *args,
+                 **kwargs):
+        """
+        Args:
+            first_stage_config: OmegaConf config for the first-stage autoencoder.
+            cond_stage_config: OmegaConf config for the conditioning encoder/module.
+            num_timesteps_cond: Number of condition timesteps used for schedule shortening.
+            cond_stage_key: Batch key for conditioning input (e.g., "instruction").
+            cond_stage_trainable: Whether the conditioning module is trainable.
+            cond_stage_forward: Optional method name to call on the cond model instead of the default.
+            conditioning_key: Conditioning mode (e.g., "crossattn", "concat").
+            uncond_prob: Probability of dropping/zeroing the condition for classifier-free guidance.
+            uncond_type: Strategy for the unconditional condition ("zero_embed" or "empty_seq").
+            scale_factor: Fixed latent scale multiplier if not using std-scaling.
+            scale_by_std: If True, compute the scale as 1/std of latents on the first batch.
+            encoder_type: "2d" (per-frame) or "3d" (volumetric) first-stage behavior.
+            only_model: If True, load only inner model weights when restoring from ckpt.
+            noise_strength: Extra offset-noise strength for inputs (when > 0).
+            use_dynamic_rescale: If True, apply a time-dependent rescaling array.
+            base_scale: Target base scale used by dynamic rescaling after turning_step.
+            turning_step: Steps to transition from 1.0 to base_scale in dynamic rescaling.
+            interp_mode: Flag for interpolation-specific behaviors (reserved).
+            fps_condition_type: Frames-per-second conditioning mode label.
+            perframe_ae: If True, encode/decode one frame at a time.
+            logdir: Optional directory for logs.
+            rand_cond_frame: If True, randomly select conditioning frames.
+            en_and_decode_n_samples_a_time: Optional per-step batch size for (en|de)code loops.
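+
+        Example (illustrative sketch; the config objects are assumed to be
+        OmegaConf nodes loaded from an experiment YAML):
+            ```python
+            ldm = LatentDiffusion(first_stage_config=cfg.first_stage,
+                                  cond_stage_config=cfg.cond_stage,
+                                  conditioning_key="crossattn",
+                                  uncond_prob=0.2,
+                                  wma_config=cfg.model,
+                                  timesteps=1000)
+            ```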
+ """ + + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # For backwards compatibility after implementation of DiffusionWrapper + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + conditioning_key = default(conditioning_key, 'crossattn') + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.noise_strength = noise_strength + self.use_dynamic_rescale = use_dynamic_rescale + self.interp_mode = interp_mode + self.fps_condition_type = fps_condition_type + self.perframe_ae = perframe_ae + + self.logdir = logdir + self.rand_cond_frame = rand_cond_frame + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + try: + self.num_downs = len( + first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + + if use_dynamic_rescale: + scale_arr1 = np.linspace(1.0, base_scale, turning_step) + scale_arr2 = np.full(self.num_timesteps, base_scale) + scale_arr = np.concatenate((scale_arr1, scale_arr2)) + to_torch = partial(torch.tensor, dtype=torch.float32) + self.register_buffer('scale_arr', to_torch(scale_arr)) + + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.first_stage_config = first_stage_config + self.cond_stage_config = cond_stage_config + self.clip_denoised = False + + self.cond_stage_forward = cond_stage_forward + self.encoder_type = encoder_type + assert (encoder_type in ["2d", "3d"]) + self.uncond_prob = uncond_prob + self.classifier_free_guidance = True if uncond_prob > 0 else False + assert (uncond_type in ["zero_embed", "empty_seq"]) + self.uncond_type = uncond_type + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model) + self.restarted_from_ckpt = True + + def make_cond_schedule(self) -> None: + """ + Build the condition timestep schedule. + """ + self.cond_ids = torch.full(size=(self.num_timesteps, ), + fill_value=self.num_timesteps - 1, + dtype=torch.long) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, + self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, + batch: Mapping[str, Any], + batch_idx: int, + dataloader_idx: int | None = None) -> None: + """ + Args: + batch: Current batch mapping. + batch_idx: Index of the batch within the epoch. + dataloader_idx: Optional dataloader index in multi-loader setups. + """ + # Only for very first batch, reset the self.scale_factor + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and \ + not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + mainlogger.info("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + mainlogger.info( + f"setting self.scale_factor to {self.scale_factor}") + mainlogger.info("### USING STD-RESCALING ###") + mainlogger.info(f"std={z.flatten().std()}") + + def register_schedule(self, + given_betas: np.ndarray | None = None, + beta_schedule: str = "linear", + timesteps: int = 1000, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3) -> None: + """ + Extend base schedule registration and optionally shorten conditioning schedule. + + Args: + given_betas: Optional precomputed beta schedule. + beta_schedule: Name of schedule function ("linear", "cosine", etc.). + timesteps: Number of diffusion steps. + linear_start: Linear schedule start beta (if used). + linear_end: Linear schedule end beta (if used). + cosine_s: Cosine schedule parameter (if used). + """ + super().register_schedule(given_betas, beta_schedule, timesteps, + linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config: OmegaConf) -> None: + """ + Build and freeze the first-stage (autoencoder) model. + + Args: + config: OmegaConf config describing the first-stage model to instantiate. + """ + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config: OmegaConf) -> None: + """ + Build the conditioning stage model. + + Args: + config: OmegaConf config describing the conditioning model to instantiate. + + """ + if not self.cond_stage_trainable: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + model = instantiate_from_config(config) + self.cond_stage_model = model + + def get_learned_conditioning(self, c: Any) -> Tensor: + """ + Encode conditioning input into an embedding tensor. + + Args: + c: Raw conditioning input (tensor, list/dict of strings, tokens, etc.). + + Returns: + Conditioning embedding as a tensor (shape depends on cond model). + """ + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable( + self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def get_first_stage_encoding( + self, + encoder_posterior: DiagonalGaussianDistribution | Tensor, + noise: Tensor | None = None) -> Tensor: + """ + Convert encoder posterior to latent z and apply scaling. + + Args: + encoder_posterior: First-stage output; either a Gaussian posterior or a latent tensor. + noise: Optional noise for sampling if posterior is Gaussian. + + Returns: + Scaled latent tensor z. 
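+
+        Example (illustrative):
+            ```python
+            posterior = self.first_stage_model.encode(x)  # DiagonalGaussianDistribution
+            z = self.get_first_stage_encoding(posterior)  # sampled and scaled latent
+            ```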
+        """
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample(noise=noise)
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(
+                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+            )
+        return self.scale_factor * z
+
+    @torch.no_grad()
+    def encode_first_stage(self, x: Tensor) -> Tensor:
+        """
+        Encode input frames/images into latent space.
+
+        Args:
+            x: Input tensor of shape (B, C, H, W) or (B, C, T, H, W).
+
+        Returns:
+            Latent tensor with shape matched to the input layout.
+        """
+        if self.encoder_type == "2d" and x.dim() == 5:
+            b, _, t, _, _ = x.shape
+            x = rearrange(x, 'b c t h w -> (b t) c h w')
+            reshape_back = True
+        else:
+            reshape_back = False
+
+        ## Consumes more GPU memory but is faster
+        if not self.perframe_ae:
+            encoder_posterior = self.first_stage_model.encode(x)
+            results = self.get_first_stage_encoding(encoder_posterior).detach()
+        else:  ## Consumes less GPU memory but is slower
+            results = []
+            for index in range(x.shape[0]):
+                frame_batch = self.first_stage_model.encode(x[index:index +
+                                                              1, :, :, :])
+                frame_result = self.get_first_stage_encoding(
+                    frame_batch).detach()
+                results.append(frame_result)
+            results = torch.cat(results, dim=0)
+
+        if reshape_back:
+            results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
+
+        return results
+
+    def decode_core(self, z: Tensor, **kwargs: Any) -> Tensor:
+        """
+        Decode latent z back to pixel space (2D or per-frame).
+
+        Args:
+            z: Latent tensor (B, C, ...).
+
+        Returns:
+            Decoded tensor in pixel space with shape matching the input layout.
+        """
+        if self.encoder_type == "2d" and z.dim() == 5:
+            b, _, t, _, _ = z.shape
+            z = rearrange(z, 'b c t h w -> (b t) c h w')
+            reshape_back = True
+        else:
+            reshape_back = False
+
+        if not self.perframe_ae:
+            z = 1. / self.scale_factor * z
+            results = self.first_stage_model.decode(z, **kwargs)
+        else:
+            results = []
+            for index in range(z.shape[0]):
+                frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
+                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
+                results.append(frame_result)
+            results = torch.cat(results, dim=0)
+
+        if reshape_back:
+            results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
+        return results
+
+    @torch.no_grad()
+    def decode_first_stage(self, z: Tensor, **kwargs: Any) -> Tensor:
+        """
+        Decode latent with no gradient.
+
+        Args:
+            z: Latent tensor to decode.
+            **kwargs: Extra args for the decoder.
+
+        Returns:
+            Decoded tensor in pixel space.
+        """
+        return self.decode_core(z, **kwargs)
+
+    # Same as above but without the no-grad decorator
+    def differentiable_decode_first_stage(self, z: Tensor,
+                                          **kwargs: Any) -> Tensor:
+        """
+        Decode latent with gradients enabled.
+
+        Args:
+            z: Latent tensor to decode.
+
+        Returns:
+            Decoded tensor in pixel space.
+        """
+        return self.decode_core(z, **kwargs)
+
+    @torch.no_grad()
+    def get_batch_input(self,
+                        batch: Mapping[str, Any],
+                        random_uncond: bool,
+                        return_first_stage_outputs: bool = False,
+                        return_original_cond: bool = False) -> list[Any]:
+        """
+        Prepare batch: encode inputs to latents and produce conditioning embeddings.
+
+        Args:
+            batch: Batch mapping containing first-stage inputs and conditioning.
+            random_uncond: If True and `uncond_type` allows, randomly drop conditions.
+            return_first_stage_outputs: If True, also decode z to xrec for logging.
+            return_original_cond: If True, also return the raw condition object.
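+
+        Returns:
+            [z, cond_emb], optionally extended with xrec and the raw
+            condition, in that order.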
+ """ + x = super().get_input(batch, self.first_stage_key) + + # Encode video frames x to z via a 2D encoder + z = self.encode_first_stage(x) + + # Get instruction condition + cond = batch[self.cond_stage_key] + if random_uncond and self.uncond_type == 'empty_seq': + for i, ci in enumerate(cond): + if random.random() < self.uncond_prob: + cond[i] = "" + if isinstance(cond, dict) or isinstance(cond, list): + cond_emb = self.get_learned_conditioning(cond) + else: + cond_emb = self.get_learned_conditioning(cond.to(self.device)) + if random_uncond and self.uncond_type == 'zero_embed': + for i, ci in enumerate(cond): + if random.random() < self.uncond_prob: + cond_emb[i] = torch.zeros_like(cond_emb[i]) + + out = [z, cond_emb] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([xrec]) + + if return_original_cond: + out.append(cond) + + return out + + def forward( + self, + x: Tensor, + x_action: Tensor, + x_state: Tensor, + c: Any, + **kwargs: Any, + ) -> tuple[Tensor, dict[str, Tensor]]: + """ + Args: + x: Input latent (or pixel) tensor for the primary stream. + x_action: Action tensor associated with the batch. + x_state: State tensor associated with the batch. + c: Conditioning object (tensor/list/dict) consumed by `apply_model`. + + Returns: + (loss, loss_dict) tuple. + """ + + t = torch.randint(0, + self.num_timesteps, (x.shape[0], ), + device=self.device).long() + if self.use_dynamic_rescale: + x = x * extract_into_tensor(self.scale_arr, t, x.shape) + + return self.p_losses(x, x_action, x_state, c, t, **kwargs) + + def shared_step(self, batch: Mapping[str, Any], random_uncond: bool, + **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]: + """ + Common train/val step: build inputs, run forward, return loss/logs. + + Args: + batch: Input batch mapping. + random_uncond: Whether to apply classifier-free guidance dropout to cond. + **kwargs: Extra args forwarded to `forward`. + + Returns: + (loss, loss_dict) tuple. + """ + x, c = self.get_batch_input(batch, random_uncond=random_uncond) + loss, loss_dict = self(x, c, **kwargs) + + return loss, loss_dict + + def apply_model(self, x_noisy: Tensor, x_action_noisy: Tensor, + x_state_noisy: Tensor, t: Tensor, cond: Any, + **kwargs: Any) -> Tensor | tuple[Tensor, Tensor, Tensor]: + """ + Apply inner diffusion model with standardized conditioning dict. + + Args: + x_noisy: Noisy latent input for the primary stream. + x_action_noisy: Noisy action tensor aligned with t. + x_state_noisy: Noisy state tensor aligned with t. + t: Timestep indices (B,). + cond: Raw conditioning; will be wrapped into the proper key if not a dict. + **kwargs: Extra args forwarded to the inner model call. + + Returns: + Either a single tensor or a tuple of tensors (x, action, state) depending on model. + """ + if isinstance(cond, dict): + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon, x_action_recon, x_state_recon = self.model( + x_noisy, x_action_noisy, x_state_noisy, t, **cond, **kwargs) + + if isinstance(x_recon, tuple): + return x_recon[0] + else: + return x_recon, x_action_recon, x_state_recon + + def p_losses( + self, + x_start: Tensor, + x_action_start: Tensor, + x_state_start: Tensor, + cond: Any, + t: Tensor, + noise: Tensor | None = None, + action_noise: Tensor | None = None, + **kwargs: Any, + ) -> tuple[Tensor, dict[str, Tensor]]: + """ + Compute the per-step training losses for latent diffusion. 
+ + Args: + x_start: Clean primary latent (or pixel) tensor. + x_action_start: Clean action tensor aligned with x_start. + x_state_start: Clean state tensor aligned with x_start. + cond: Conditioning object; may include masks for action/state losses. + t: Timestep indices (B,). + noise: Optional epsilon noise for the primary stream (else sampled). + action_noise: Optional epsilon noise for the action stream (else sampled). + **kwargs: Extra args forwarded into `apply_model` (and logged if needed). + + Returns: + (loss, loss_dict) + """ + if self.noise_strength > 0: + b, c, f, _, _ = x_start.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x_start.device) + noise = default( + noise, lambda: torch.randn_like(x_start) + self.noise_strength + * offset_noise) + else: + noise = default(noise, lambda: torch.randn_like(x_start)) + action_noise = torch.randn(x_action_start.shape, + device=x_action_start.device) + action_noise_new = action_noise + self.input_pertub * torch.randn( + x_action_start.shape, device=x_action_start.device) + + state_noise = torch.randn(x_state_start.shape, + device=x_state_start.device) + state_noise_new = state_noise + self.input_pertub * torch.randn( + x_state_start.shape, device=x_state_start.device) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_action_noisy = self.dp_noise_scheduler_action.add_noise( + x_action_start, action_noise_new, t[:x_action_start.shape[0]]) + x_state_noisy = self.dp_noise_scheduler_state.add_noise( + x_state_start, state_noise_new, t[:x_state_start.shape[0]]) + + kwargs['x_start'] = x_start + model_output, model_action_output, model_state_output = self.apply_model( + x_noisy, x_action_noisy, x_state_noisy, t, cond, **kwargs) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + target_action = action_noise + target_state = state_noise + + loss_simple = self.get_loss(model_output, target, + mean=False).mean([1, 2, 3, 4]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_action_simple = F.mse_loss(model_action_output, + target_action, + reduction='none') + action_mask = cond['c_crossattn_action'][-1] + loss_action_simple *= action_mask + loss_action_simple = loss_action_simple.type(loss_action_simple.dtype) + loss_action_simple = reduce(loss_action_simple, 'b ... -> b (...)', + 'mean') + loss_action_simple = loss_action_simple.sum() / action_mask.sum() + loss_dict.update({f'{prefix}/loss_action_simple': loss_action_simple}) + + loss_state_simple = F.mse_loss(model_state_output, + target_state, + reduction='none') + state_mask = cond['c_crossattn_action'][-2] + loss_state_simple *= state_mask + loss_state_simple = loss_state_simple.type(loss_state_simple.dtype) + loss_state_simple = reduce(loss_state_simple, 'b ... 
-> b (...)', + 'mean') + loss_state_simple = loss_state_simple.sum() / state_mask.sum() + loss_dict.update({f'{prefix}/loss_state_simple': loss_state_simple}) + + if self.logvar.device is not self.device: + self.logvar = self.logvar.to(self.device) + logvar_t = self.logvar[t] + loss = loss_simple / torch.exp(logvar_t) + logvar_t + loss_action = loss_action_simple + loss_state = loss_state_simple + + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + loss_vlb = self.get_loss(model_output, target, + mean=False).mean(dim=(1, 2, 3, 4)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + + loss_dict.update({f'{prefix}/loss_action_vlb': loss_action}) + loss_dict.update({f'{prefix}/loss_state_vlb': loss_state}) + + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + loss_dict.update({f'{prefix}/loss_action': loss_action}) + loss_dict.update({f'{prefix}/loss_state': loss_state}) + + if cond['c_crossattn_action'][2]: + return loss + loss_state + loss_action * 0.0, loss_dict + else: + return loss + loss_action + loss_state * 0.0, loss_dict + + def training_step(self, batch: Mapping[str, Any], + batch_idx: int) -> Tensor: + """ + Lightning training step: compute loss and log metrics. + + Args: + batch: Training batch mapping. + batch_idx: Batch index within current epoch. + + Returns: + Scalar loss tensor for optimization. + """ + loss, loss_dict = self.shared_step( + batch, random_uncond=self.classifier_free_guidance) + loss_dict.update( + {'lr': self.trainer.optimizers[0].param_groups[0]['lr']}) + loss_dict.update({ + 'lr_action_unet': + self.trainer.optimizers[0].param_groups[1]['lr'] + }) + # Sync_dist | rank_zero_only + self.log_dict(loss_dict, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False) + if (batch_idx + 1) % self.log_every_t == 0: + mainlogger.info( + f"batch:{batch_idx}|epoch:{self.current_epoch} [globalstep:{self.global_step}]: loss={loss}" + ) + return loss + + @torch.no_grad() + def validation_step(self, batch: Mapping[str, Any], + batch_idx: int) -> None: + """ + Lightning validation step: compute loss and log metrics. + + Args: + batch: Validation batch mapping. + batch_idx: Batch index in validation loop. + """ + _, loss_dict_no_ema = self.shared_step(batch, random_uncond=False) + self.log_dict(loss_dict_no_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True) + + def _get_denoise_row_from_list(self, + samples: Sequence[Tensor], + desc: str = '') -> Tensor: + """ + Decode a list of latents and pack into a grid for visualization. + + Args: + samples: Sequence of latent tensors to decode and tile. + desc: Optional tqdm description string. + + Returns: + Grid image tensor suitable for logging. 
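+
+        Example (illustrative):
+
+            grid = self._get_denoise_row_from_list(z_denoise_row, desc='denoise')
+            # `grid` is a single (C, H, W) image tensor produced by make_grid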
+
+        """
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device)))
+        n_log_timesteps = len(denoise_row)
+
+        denoise_row = torch.stack(denoise_row)
+
+        if denoise_row.dim() == 5:
+            denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+            denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+            denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps)
+        elif denoise_row.dim() == 6:
+            video_length = denoise_row.shape[3]
+            denoise_grid = rearrange(denoise_row, 'n b c t h w -> b n c t h w')
+            denoise_grid = rearrange(denoise_grid,
+                                     'b n c t h w -> (b n) c t h w')
+            denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w')
+            denoise_grid = make_grid(denoise_grid, nrow=video_length)
+        else:
+            raise ValueError
+
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self,
+                   batch: Mapping[str, Any],
+                   sample: bool = True,
+                   ddim_steps: int = 200,
+                   ddim_eta: float = 1.0,
+                   plot_denoise_rows: bool = False,
+                   unconditional_guidance_scale: float = 1.0,
+                   **kwargs: Any) -> dict[str, Tensor]:
+        """ Log images for LatentDiffusion """
+        # Controls the number of sampled images for logging; larger values may cause OOM
+        sampled_img_num = 2
+        for key in batch.keys():
+            batch[key] = batch[key][:sampled_img_num]
+
+        # TBD: currently, classifier-free guidance sampling is only supported by DDIM
+        use_ddim = ddim_steps is not None
+        log = dict()
+        z, c, xrec, xc = self.get_batch_input(batch,
+                                              random_uncond=False,
+                                              return_first_stage_outputs=True,
+                                              return_original_cond=True)
+
+        N = xrec.shape[0]
+        log["reconst"] = xrec
+        log["condition"] = xc
+
+        if sample:
+            uc = None
+            with self.ema_scope("Plotting"):
+                samples, _, _, z_denoise_row = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc,
+                    x0=z,
+                    **kwargs)
+                x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        return log
+
+    def p_mean_variance(
+        self,
+        x: Tensor,
+        c: Any,
+        t: Tensor,
+        clip_denoised: bool,
+        return_x0: bool = False,
+        score_corrector: Any = None,
+        corrector_kwargs: Mapping[str, Any] | None = None,
+        **kwargs: Any
+    ) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor, Tensor, Tensor]:
+        """
+        Predict posterior parameters (and optionally x0) at timestep t.
+
+        Args:
+            x: Current latent at timestep t.
+            c: Conditioning object passed to the inner model/score corrector.
+            t: Timestep indices (B,).
+            clip_denoised: If True, clamp predicted x0 to [-1, 1].
+            return_x0: If True, also return predicted x0.
+            score_corrector: Optional score-corrector object with `modify_score`.
+            corrector_kwargs: Extra kwargs for the score corrector.
+            **kwargs: Forwarded to `apply_model`.
+
+        Returns:
+            (mean, var, log_var) or (mean, var, log_var, x0) tensors.
+        """
+
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, **kwargs)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(self, model_out, x, t, c,
+                                                     **corrector_kwargs)
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
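+
+        # q_posterior returns the closed-form Gaussian posterior
+        # q(x_{t-1} | x_t, x_0): its mean is a fixed affine combination of
+        # the predicted x_0 and x_t, and its (log-)variance depends only on t.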
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, + x: Tensor, + c: Any, + t: Tensor, + clip_denoised: bool = False, + repeat_noise: bool = False, + return_x0: bool = False, + temperature: float = 1.0, + noise_dropout: float = 0.0, + score_corrector: Any = None, + corrector_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any) -> Tensor | tuple[Tensor, Tensor]: + """ + Draw a single reverse-diffusion step (optionally return x0). + + Args: + x: Current latent at timestep t. + c: Conditioning object for the model. + t: Timestep indices (B,). + clip_denoised: Clamp predicted x0 to [-1, 1] when forming the mean. + repeat_noise: If True, reuse the same noise across batch. + return_x0: If True, also return the predicted x0. + temperature: Temperature for sampling noise scale. + noise_dropout: Dropout probability applied to the sampled noise. + score_corrector: Optional score-corrector to adjust model outputs. + corrector_kwargs: Extra kwargs for the corrector. + **kwargs: Forwarded to `p_mean_variance`. + + Returns: + Next latent (and optionally x0). + """ + + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_x0=return_x0, \ + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # No noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1, ) * (len(x.shape) - 1))) + + if return_x0: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, + cond: Any, + shape: Sequence[int], + return_intermediates: bool = False, + x_T: Tensor | None = None, + verbose: bool = True, + callback: Callable[[int], Any] | None = None, + timesteps: int | None = None, + mask: Tensor | None = None, + x0: Tensor | None = None, + img_callback: Callable[[Tensor, int], Any] | None = None, + start_T: int | None = None, + log_every_t: int | None = None, + **kwargs: Any) -> Tensor | tuple[Tensor, list[Tensor]]: + """ + Run the full reverse process from noise to sample(s). + + Args: + cond: Conditioning object (tensor/dict/list), optionally noised when cond schedule is shortened. + shape: Output latent shape (B, C, ...). + return_intermediates: If True, also return intermediate latents. + x_T: Optional starting noise latent (else sampled from N(0, I)). + verbose: If True, show tqdm progress. + callback: Optional function called with the current timestep i. + timesteps: Number of reverse steps to perform (default: self.num_timesteps). + mask: Optional inpainting mask; ones keep original x0 regions. + x0: Optional original latent for masked regions (when using `mask`). + img_callback: Optional function called with (img, i) every step. + start_T: Optional cap to limit starting step (min(timesteps, start_T)). + log_every_t: Logging frequency for collecting intermediates (defaults to self.log_every_t). 
+ + Returns: + Final latent sample (and optionally the list of intermediates). + """ + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + # Sample an initial noise + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + if start_T is not None: + timesteps = min(timesteps, start_T) + + iterator = tqdm( + reversed(range(0, timesteps)), desc='Sampling t', + total=timesteps) if verbose else reversed(range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] + + for i in iterator: + ts = torch.full((b, ), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, + t=tc, + noise=torch.randn_like(cond)) + + img = self.p_sample(img, + cond, + ts, + clip_denoised=self.clip_denoised, + **kwargs) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, + cond, + batch_size: int = 16, + return_intermediates: bool = False, + x_T: Tensor | None = None, + verbose: bool = True, + timesteps: int | None = None, + mask: Tensor | None = None, + x0: Tensor | None = None, + shape: Sequence[int] | None = None, + **kwargs: Any) -> Tensor | tuple[Tensor, list[Tensor]]: + """ + Convenience wrapper to run `p_sample_loop` with a full batch. + + Args: + cond: Conditioning object; dict/list items are truncated to batch_size. + batch_size: Number of samples to generate. + return_intermediates: If True, return intermediates as well. + x_T: Optional starting noise latent (else sampled). + verbose: Whether to print sampling progress. + timesteps: Number of reverse steps (default: self.num_timesteps). + mask: Optional mask for partial generation/inpainting. + x0: Optional original latent used with `mask` during sampling. + shape: Optional output shape; if None, uses (B, C, T, H, W) from model config. + + Returns: + Final latent (and optionally intermediates). + """ + if shape is None: + shape = (batch_size, self.channels, self.temporal_length, + *self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: + cond[key][:batch_size] if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance( + cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + mask=mask, + x0=x0, + **kwargs) + + @torch.no_grad() + def sample_log(self, cond: Any, batch_size: int, ddim: bool, + ddim_steps: int, + **kwargs: Any) -> tuple[Any, Any, Any, Any]: + """ + Produce samples (and intermediates), optionally via DDIM sampler. + + Args: + cond: Conditioning object passed to the sampler. + batch_size: Number of samples to generate. + ddim: If True, use DDIM sampler; otherwise use ancestral sampling. + ddim_steps: Number of DDIM steps when `ddim` is True. 
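+
+        Returns:
+            (samples, actions, states, intermediates); `actions` and `states`
+            are produced by the DDIM sampler and are None on the ancestral
+            sampling path.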
+
+        """
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.temporal_length, *self.image_size)
+            samples, actions, states, intermediates = ddim_sampler.sample(
+                ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+
+        else:
+            # The ancestral path has no action/state heads; return None for both
+            actions, states = None, None
+            samples, intermediates = self.sample(cond=cond,
+                                                 batch_size=batch_size,
+                                                 return_intermediates=True,
+                                                 **kwargs)
+
+        return samples, actions, states, intermediates
+
+    def configure_schedulers(
+            self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
+        """
+        Build LR scheduler dict compatible with PyTorch Lightning.
+
+        Args:
+            optimizer: Optimizer instance for which to build the scheduler dict.
+
+        Returns:
+            Dict with keys {'scheduler', 'interval', 'frequency'} per Lightning API.
+        """
+        assert 'target' in self.scheduler_config
+        scheduler_name = self.scheduler_config.target.split('.')[-1]
+        interval = self.scheduler_config.interval
+        frequency = self.scheduler_config.frequency
+        if scheduler_name == "LambdaLRScheduler":
+            scheduler = instantiate_from_config(self.scheduler_config)
+            scheduler.start_step = self.global_step
+            lr_scheduler = {
+                'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+                'interval': interval,
+                'frequency': frequency
+            }
+        elif scheduler_name == "CosineAnnealingLRScheduler":
+            scheduler = instantiate_from_config(self.scheduler_config)
+            decay_steps = scheduler.decay_steps
+            last_step = -1 if self.global_step == 0 else scheduler.start_step
+            lr_scheduler = {
+                'scheduler':
+                CosineAnnealingLR(optimizer,
+                                  T_max=decay_steps,
+                                  last_epoch=last_step),
+                'interval':
+                interval,
+                'frequency':
+                frequency
+            }
+        else:
+            raise NotImplementedError
+        return lr_scheduler
+
+
+class LatentVisualDiffusion(LatentDiffusion):
+    """
+    Visual-conditioned latent diffusion with action/state heads and schedulers.
+    """
+
+    def __init__(self,
+                 img_cond_stage_config: OmegaConf,
+                 image_proj_stage_config: OmegaConf,
+                 noise_scheduler_config: OmegaConf,
+                 dp_optimizer_config: OmegaConf,
+                 dp_ema_config: OmegaConf,
+                 freeze_embedder: bool = True,
+                 image_proj_model_trainable: bool = True,
+                 n_obs_steps_imagen: int = 2,
+                 n_obs_steps_acting: int = 2,
+                 agent_state_dim: int = 14,
+                 agent_action_dim: int = 14,
+                 global_emb_dim: int = 1024,
+                 input_pertub: float = 0.1,
+                 lr_scheduler: str = 'cosine',
+                 lr_warmup_steps: int = 500,
+                 num_epochs: int = 15000,
+                 gradient_accumulate_every: int = 1,
+                 use_scheduler: bool = False,
+                 dp_use_ema: bool = False,
+                 pretrained_checkpoint: str | None = None,
+                 decision_making_only: bool = True,
+                 *args,
+                 **kwargs):
+        """
+        Args:
+            img_cond_stage_config: OmegaConf for the *image* conditioning encoder.
+            image_proj_stage_config: OmegaConf for the image feature projector.
+            noise_scheduler_config: OmegaConf for DP noise schedulers (state/action).
+            dp_optimizer_config: OmegaConf for optimizer params of the UNet heads.
+            dp_ema_config: Optional EMA config for the action UNet.
+            freeze_embedder: If True, freeze the image embedder params.
+            image_proj_model_trainable: If True, train the image projector.
+            n_obs_steps_imagen: Number of observed steps for image conditions.
+            n_obs_steps_acting: Number of observed steps for acting head.
+            agent_state_dim: Dimension of agent state vector.
+            agent_action_dim: Dimension of agent action vector.
+            global_emb_dim: Embedding size for state/action/text/image fusion.
+            input_pertub: Perturbation scale added to action/state noises.
+            lr_scheduler: Name of LR scheduler (for SelectiveLRScheduler wrapper).
+ lr_warmup_steps: Warmup steps for scheduler creation. + num_epochs: Total training epochs. + gradient_accumulate_every: Gradient accumulation steps. + use_scheduler: If True, enable LR scheduling. + dp_use_ema: If True, maintain EMA for action UNet head. + pretrained_checkpoint: Optional path to a pretrained checkpoint. + decision_making_only: If True, use decision-only augmentation path. + """ + + super().__init__(*args, **kwargs) + self.image_proj_model_trainable = image_proj_model_trainable + self.agent_state_dim = agent_state_dim + self.agent_action_dim = agent_action_dim + self.global_emb_dim = global_emb_dim + self.n_obs_steps_imagen = n_obs_steps_imagen + self.n_obs_steps_acting = n_obs_steps_acting + self.decision_making_only = decision_making_only + + self._init_embedder(img_cond_stage_config, freeze_embedder) + self._init_img_ctx_projector(image_proj_stage_config, + image_proj_model_trainable) + self._init_dp_noise_scheduler(noise_scheduler_config) + self._init_projectors() + if dp_use_ema: + self._init_dp_ema(dp_ema_config) + + # Create a pos_embedder for state and action info, our state and action have an unified vector space + self.pos_embedder = SinusoidalPosEmb(self.global_emb_dim) + self.register_buffer('cond_pos_emb', + self.pos_embedder(torch.arange( + 0, 16))) #NOTE HAND-CODE 16 + + self.input_pertub = input_pertub + self.dp_optimizer_config = dp_optimizer_config + self.lr_scheduler = lr_scheduler + self.lr_warmup_steps = lr_warmup_steps + self.num_epochs = num_epochs + self.gradient_accumulate_every = gradient_accumulate_every + self.use_scheduler = use_scheduler + self.dp_use_ema = dp_use_ema + self.pretrained_checkpoint = pretrained_checkpoint + + def _init_img_ctx_projector(self, config: OmegaConf, + trainable: bool) -> None: + """ + Instantiate image context projector; optionally freeze. + + Args: + config: OmegaConf for the projector module to instantiate. + trainable: If False, freeze the projector. + """ + self.image_proj_model = instantiate_from_config(config) + if not trainable: + self.image_proj_model.eval() + self.image_proj_model.train = disabled_train + for param in self.image_proj_model.parameters(): + param.requires_grad = False + + def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None: + """ + Instantiate the image embedder; optionally freeze. + + Args: + config: OmegaConf for the embedder to instantiate. + freeze: If True, set to eval/disable grads. + """ + self.embedder = instantiate_from_config(config) + if freeze: + self.embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def init_normalizers(self, normalize_config: OmegaConf, + dataset_stats: Mapping[str, Any]) -> None: + """ + Create normalization and unnormalization utilities. + + Args: + normalize_config: Config with shapes and normalization modes. + dataset_stats: Statistics dict used to compute normalization. + """ + self.normalize_inputs = Normalize( + normalize_config.input_shapes, + normalize_config.input_normalization_modes, dataset_stats) + self.unnormalize_outputs = Unnormalize( + normalize_config.output_shapes, + normalize_config.output_normalization_modes, dataset_stats) + + def _init_dp_noise_scheduler(self, config: OmegaConf) -> None: + """ + Instantiate separate DP noise schedulers for action and state. + + Args: + config: OmegaConf used to create scheduler instances. 
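+
+        Example config (illustrative; the actual target and params come
+        from the experiment YAML):
+
+            target: diffusers.DDPMScheduler
+            params:
+              num_train_timesteps: 1000
+              beta_schedule: squaredcos_cap_v2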
+ """ + self.dp_noise_scheduler_action = instantiate_from_config(config) + self.dp_noise_scheduler_state = instantiate_from_config(config) + + def _init_dp_ema(self, config: OmegaConf | None) -> None: + """ + Initialize EMA for UNet head. + + Args: + config: EMA config, must contain 'params' sub-dict. + """ + self.dp_ema_model = copy.deepcopy( + self.model.diffusion_model.action_unet) + self.dp_ema_model_on_device = False + self.dp_ema = EMAModel(**config['params'], model=self.dp_ema_model) + + def _init_projectors(self): + """ + Build small MLP projectors and positional embeddings for state/action. + """ + self.state_projector = MLPProjector(self.agent_state_dim, + 1024) # NOTE HAND CODE + self.action_projector = MLPProjector(self.agent_action_dim, + 1024) # NOTE HAND CODE + + self.agent_action_pos_emb = nn.Parameter( + torch.randn(1, 16, self.global_emb_dim)) + self.agent_state_pos_emb = nn.Parameter( + torch.randn(1, self.n_obs_steps_imagen, self.global_emb_dim)) + + def _get_augmented_batch( + self, + z: Tensor, + state: Tensor, + obs_state: Tensor, + action: Tensor, + ins: Tensor, + null_ins: Tensor, + img: Tensor, + sim_mode: bool = False, + pre_action: Tensor | None = None, + logging: bool = False) -> tuple[Tensor, Tensor, list[Tensor]]: + """ + Construct augmented conditioning batch for decision/simulation modes. + + Args: + z: Latent video tensor (B, C, ...). + state: Full state tensor (B, T, D_s). + obs_state: Observed state embeddings (B, T, E). + action: Action embeddings (B, T, E). + ins: Instruction/text embeddings (B, L, E) after projector. + null_ins: Null/empty instruction embedding for CFG. + img: Image conditioning embedding (B, E_img) or batched equivalent. + sim_mode: If True, build simulated-mode batch; else decision-making. + pre_action: Optional previous action(s). (unused here; reserved) + logging: If True, may include extra returns for logs. (unused) + + Returns: + Tuple of (z, state, [mode_batch]) where mode_batch is a single tensor combining the selected conditioning streams. + """ + + b, _, t, _, _ = z.shape + if self.decision_making_only: + mode_batch = torch.cat([obs_state, ins, img], dim=1) + return z, state, [mode_batch] + + if not sim_mode: + zero_action = torch.zeros_like(action) + mode_batch = torch.cat([obs_state, zero_action, ins, img], dim=1) + else: + null_ins_batch = null_ins.repeat_interleave(repeats=ins.shape[0], + dim=0) + mode_batch = torch.cat([obs_state, action, null_ins_batch, img], + dim=1) + return z, state, [mode_batch] + + def on_train_batch_end(self, outputs: Any, batch: Mapping[str, Any], + batch_idx: int) -> None: + """ + Update EMA for action UNet after each train batch (if enabled). + + Args: + batch: Current training batch mapping. + batch_idx: Batch index within the epoch. + """ + if self.dp_use_ema: + if self.dp_ema_model is not None and not self.dp_ema_model_on_device: + device = self.model.device + self.dp_ema_model.to(device) + self.dp_ema_model_on_device = True + self.dp_ema.step(self.model.diffusion_model.action_unet) + + def shared_step(self, batch: Mapping[str, Any], random_uncond: bool, + **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]: + """ + Common train/val step for visual diffusion. + + Args: + batch: Input batch mapping. + random_uncond: Whether to apply classifier-free guidance dropout. + + Returns: + (loss, loss_dict) tuple. 
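+
+        Example (as invoked from `training_step`):
+
+            loss, loss_dict = self.shared_step(
+                batch, random_uncond=self.classifier_free_guidance)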
+ """ + x, x_action, x_state, c, fs = self.get_batch_input( + batch, random_uncond=random_uncond, return_fs=True) + kwargs.update({"fs": fs.long()}) + loss, loss_dict = self(x, x_action, x_state, c, **kwargs) + return loss, loss_dict + + def get_batch_input(self, + batch: Mapping[str, Any], + random_uncond: bool, + return_first_stage_outputs: bool = False, + return_original_cond: bool = False, + return_fs: bool = False, + return_cond_frame: bool = False, + return_original_input: bool = False, + logging: bool = False, + **kwargs: Any) -> list[Any]: + """ + Prepare model inputs & conditioning from a raw training batch. + + Args: + batch: Batch mapping with keys like image/state/action/obs/etc. + random_uncond: Apply stochastic condition dropout for CFG. + return_first_stage_outputs: If True, also return xrec (decoded z). + return_original_cond: If True, also return raw instruction text. + return_fs: If True, return fps or frame_stride per config. + return_cond_frame: If True, return conditioning frames (obs images). + return_original_input: If True, return original x (pre-encoding). + logging: If True, append sim_mode flag at the end. + + Returns: + A list of inputs + """ + # x: b c t h w + x = super().get_input(batch, self.first_stage_key) + b, _, t, _, _ = x.shape + # Get actions: b t d + action = super().get_input(batch, 'action') + # Get states: b t d + state = super().get_input(batch, 'next.state') + # Get observable states: b t d + obs_state = super().get_input(batch, 'observation.state') + # Get observable images: b c t h w + obs = super().get_input(batch, 'observation.image') + + # Encode video frames x to z via a 2D encoder + z = self.encode_first_stage(x) + + cond = {} + # Get instruction condition + cond_ins_input = batch[self.cond_stage_key] + if isinstance(cond_ins_input, dict) or isinstance( + cond_ins_input, list): + cond_ins_emb = self.get_learned_conditioning(cond_ins_input) + else: + cond_ins_emb = self.get_learned_conditioning( + cond_ins_input.to(self.device)) + # To support classifier-free guidance, randomly drop out only text conditioning + # 5%, only image conditioning 5%, and both 5%. + if random_uncond: + random_num = torch.rand(b, device=x.device) + else: + random_num = torch.ones(b, device=x.device) + prompt_mask = rearrange(random_num < 2 * self.uncond_prob, + "n -> n 1 1") + null_prompt = self.get_learned_conditioning([""]) + cond_ins_emb = torch.where(prompt_mask, null_prompt, + cond_ins_emb.detach()) + + # Get conditioning frames + cond_frame_index = 0 + img = obs[:, :, -1, ...] 
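+        # With uncond_prob = p, the text mask above and the image mask below
+        # partition random_num into CFG dropout bands: [0, p) drops text only,
+        # [p, 2p) drops both, [2p, 3p) drops the image only, and [3p, 1)
+        # keeps both conditions (matching 5%/5%/5% when p = 0.05).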
+ input_mask = 1 - rearrange( + (random_num >= self.uncond_prob).float() * + (random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1") + + cond_img = input_mask * img + cond_img_emb = self.embedder(cond_img) + cond_img_emb = self.image_proj_model(cond_img_emb) + + if self.model.conditioning_key == 'hybrid': + if self.interp_mode: + img_cat_cond = torch.zeros_like(z) + img_cat_cond[:, :, 0, :, :] = z[:, :, 0, :, :] + img_cat_cond[:, :, -1, :, :] = z[:, :, -1, :, :] + else: + img_cat_cond = z[:, :, cond_frame_index, :, :] + img_cat_cond = img_cat_cond.unsqueeze(2) + img_cat_cond = repeat(img_cat_cond, + 'b c t h w -> b c (repeat t) h w', + repeat=z.shape[2]) + cond["c_concat"] = [img_cat_cond] + + cond_action = self.action_projector(action) + cond_action_emb = self.agent_action_pos_emb + cond_action + # Get conditioning states + cond_state = self.state_projector(obs_state) + cond_state_emb = self.agent_state_pos_emb + cond_state + + if self.decision_making_only: + is_sim_mode = False + else: + is_sim_mode = torch.rand(1) < 0.5 + z, state, cond["c_crossattn"] = self._get_augmented_batch( + z, + state, + cond_state_emb, + cond_action_emb, + cond_ins_emb, + null_prompt, + cond_img_emb, + sim_mode=is_sim_mode, + logging=logging) + + cond["c_crossattn_action"] = [ + obs[:, :, -self.n_obs_steps_acting:], + state[:, -self.n_obs_steps_acting:], is_sim_mode, + batch['state_mask'], batch['action_mask'] + ] + + out = [z, action, state, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([xrec]) + if return_original_cond: + out.append(cond_ins_input) + if return_fs: + if self.fps_condition_type == 'fs': + fs = super().get_input(batch, 'frame_stride') + elif self.fps_condition_type == 'fps': + fs = super().get_input(batch, 'fps') + out.append(fs) + if return_cond_frame: + out.append(obs) + if return_original_input: + out.append(x) + if logging: + out.append(is_sim_mode) + return out + + @torch.no_grad() + def log_images(self, + batch: Mapping[str, Any], + sample: bool = True, + ddim_steps: int = 50, + ddim_eta: float = 1.0, + plot_denoise_rows: bool = False, + unconditional_guidance_scale: float = 1.0, + mask: Tensor | None = None, + **kwargs) -> dict[str, Tensor]: + """ + Log images for LatentVisualDiffusion + + Args: + batch: Batch mapping used to form inputs/conditions. + sample: If True, also run sampling for visualization. + ddim_steps: Number of DDIM steps when using DDIM. + ddim_eta: DDIM eta parameter (stochasticity). + plot_denoise_rows: If True, include denoise progression grid. + unconditional_guidance_scale: Guidance scale for CFG sampling. + mask: Optional mask for sampling-time inpainting. + + Returns: + Dict of visualization tensors (images/actions/states/progress). 
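+            Typically populated keys (per the implementation below):
+            'reconst', 'condition', 'instruction', 'image_condition',
+            'samples', 'action', 'state', 'video_idx', and optionally
+            'denoise_row'.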
+        """
+
+        # sampled_img_num controls how many sampled images are logged; larger values may cause OOM
+        sampled_img_num = 1
+        for key in batch.keys():
+            batch[key] = batch[key][:sampled_img_num]
+
+        # TBD: currently, classifier-free guidance sampling is only supported by DDIM
+        use_ddim = ddim_steps is not None
+        log = dict()
+
+        z, act, state, c, xrec, xc, fs, cond_x, is_sim_mode = self.get_batch_input(
+            batch,
+            random_uncond=False,
+            return_first_stage_outputs=True,
+            return_original_cond=True,
+            return_fs=True,
+            return_cond_frame=True,
+            logging=True)
+
+        kwargs['x_start'] = z
+
+        N = xrec.shape[0]
+        log["image_condition"] = cond_x
+        log["reconst"] = xrec
+        if is_sim_mode:
+            xc = ["NULL"]
+        xc_with_fs = []
+        for idx, content in enumerate(xc):
+            xc_with_fs.append(content + '_fs=' + str(fs[idx].item()))
+        log['instruction'] = xc
+        log["condition"] = xc_with_fs
+        kwargs.update({"fs": fs.long()})
+
+        if sample:
+            uc = None
+            with self.ema_scope("Plotting"):
+                samples, action_samples, state_samples, z_denoise_row = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc,
+                    x0=z,
+                    **kwargs)
+
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+
+            # Log actions
+            mb, mt, _ = batch['action_mask'].shape
+            act_mask = batch['action_mask'] == 1.0
+            action_target = act[act_mask].reshape(mb, mt, -1)
+            action_samples = action_samples[act_mask].reshape(mb, mt, -1)
+            log["action"] = torch.cat((action_target, action_samples), dim=0)
+
+            # Log states
+            mb, mt, _ = batch['state_mask'].shape
+            state_mask = batch['state_mask'] == 1.0
+            state_target = state[state_mask].reshape(mb, mt, -1)
+            state_samples = state_samples[state_mask].reshape(mb, mt, -1)
+            log["state"] = torch.cat((state_target, state_samples), dim=0)
+
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        log["video_idx"] = batch["path"][0].split('/')[-1][:-4]
+        return log
+
+    def configure_optimizers(self):
+        """ configure_optimizers for LatentVisualDiffusion """
+        lr = self.learning_rate
+
+        params = [
+            param for name, param in self.model.named_parameters()
+            if not name.startswith("diffusion_model.action_unet")
+            and not name.startswith("diffusion_model.state_unet")
+        ]
+        params_unet_head = list(
+            self.model.diffusion_model.action_unet.parameters()) + list(
+                self.model.diffusion_model.state_unet.parameters())
+
+        mainlogger.info(f"@Training [{len(params)}] Full Parameters.")
+
+        if self.cond_stage_trainable:
+            params_cond_stage = [
+                p for p in self.cond_stage_model.parameters()
+                if p.requires_grad
+            ]
+            mainlogger.info(
+                f"@Training [{len(params_cond_stage)}] Parameters for Cond_stage_model."
+            )
+            params.extend(params_cond_stage)
+
+        if self.image_proj_model_trainable:
+            mainlogger.info(
+                f"@Training [{len(list(self.image_proj_model.parameters()))}] Parameters for Image_proj_model."
+ ) + params.extend(list(self.image_proj_model.parameters())) + + if self.learn_logvar: + mainlogger.info('Diffusion model optimizing logvar') + if isinstance(params[0], dict): + params.append({"params": [self.logvar]}) + else: + params.append(self.logvar) + + params_group = [{ + 'params': params, + 'lr': lr + }, { + 'params': + params_unet_head, + 'lr': + self.dp_optimizer_config['params']['lr'], + 'betas': + self.dp_optimizer_config['params']['betas'], + 'eps': + self.dp_optimizer_config['params']['eps'], + 'weight_decay': + self.dp_optimizer_config['params']['weight_decay'] + }] + optimizer = torch.optim.AdamW(params_group, lr=lr) + + if self.use_scheduler: + + # mainlogger.info("Setting up scheduler...") + lr_scheduler = get_scheduler( + self.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=self.lr_warmup_steps, + num_training_steps=(self.datasets_len * self.num_epochs) // + self.gradient_accumulate_every, # 50 is handcode + last_epoch=-1) + + scheduler = SelectiveLRScheduler( + optimizer=optimizer, + base_scheduler=lr_scheduler, + group_indices=[1], + default_lr=[lr, self.dp_optimizer_config['params']['lr']]) + return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] + + return [optimizer] + + +class DiffusionWrapper(pl.LightningModule): + """Thin wrapper that routes inputs/conditions to the underlying diffusion model.""" + + def __init__(self, diff_model_config: OmegaConf, + conditioning_key: str | None) -> None: + """ + Args: + diff_model_config: OmegaConf describing the inner diffusion model to instantiate. + conditioning_key: How conditioning is applied. + """ + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + + def forward( + self, + x: Tensor, + x_action: Tensor | None, + x_state: Tensor | None, + t: Tensor, + c_concat: Sequence[Tensor] | None = None, + c_crossattn: Sequence[Tensor] | None = None, + c_crossattn_action: list[Any] | None = None, + c_adm: Tensor | None = None, + s: Tensor | None = None, + mask: Tensor | None = None, + **kwargs: Any, + ) -> Any: + """ + Route input(s) and condition(s) into the inner diffusion model based on `conditioning_key`. + + Args: + x: Primary input tensor (e.g., latent/image) at timestep `t`. + x_action: Action stream tensor (used by 'hybrid' variants). + x_state: State stream tensor (used by 'hybrid' variants). + t: Timestep indices (B,). + c_concat: List of tensors to be concatenated channel-wise with `x` (for 'concat' / 'hybrid' modes). + c_crossattn: List of context tensors concatenated along sequence/channel dim for cross-attention. + c_crossattn_action: Mixed list used by action/state heads. + c_adm: Class/ADM conditioning (e.g., labels) when required by '*adm*' modes. + s: Optional additional time-like / scalar conditioning (e.g., fps/frame-stride) for '*time*' modes. + mask: Optional spatial/temporal mask (e.g., inpainting) for '*mask*' modes. + **kwargs: Extra keyword arguments forwarded to the inner diffusion model. + + Returns: + Output from the inner diffusion model (tensor or tuple, depending on the model). 
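+
+        Example (illustrative, for the 'hybrid' key used by this project;
+        `fs` is the frame-stride conditioning forwarded via **kwargs):
+
+            cond = {'c_concat': [img_cat_cond],
+                    'c_crossattn': [context_emb],
+                    'c_crossattn_action': ctx_action_list}
+            out = wrapper(x, x_action, x_state, t, **cond, fs=fs)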
+ """ + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, **kwargs) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, **kwargs) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + cc_action = c_crossattn_action + out = self.diffusion_model(xc, + x_action, + x_state, + t, + context=cc, + context_action=cc_action, + **kwargs) + elif self.conditioning_key == 'resblockcond': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs) + elif self.conditioning_key == 'hybrid-time': + assert s is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s) + elif self.conditioning_key == 'concat-time-mask': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t, context=None, s=s, mask=mask) + elif self.conditioning_key == 'concat-adm-mask': + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=None, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-adm-mask': + cc = torch.cat(c_crossattn, 1) + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask) + elif self.conditioning_key == 'hybrid-time-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + else: + raise NotImplementedError() + + return out diff --git a/src/unifolm_wma/models/diffusion_head/__init__.py b/src/unifolm_wma/models/diffusion_head/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/models/diffusion_head/base_nets.py b/src/unifolm_wma/models/diffusion_head/base_nets.py new file mode 100644 index 0000000..8e96baf --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/base_nets.py @@ -0,0 +1,217 @@ +""" +Contains torch Modules that correspond to basic network building blocks, like +MLP, RNN, and CNN backbones. +""" + +import abc +import numpy as np +import torch +import torch.nn.functional as F + + +class Module(torch.nn.Module): + """ + Base class for networks. The only difference from torch.nn.Module is that it + requires implementing @output_shape. + """ + + @abc.abstractmethod + def output_shape(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. 
+ + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + raise NotImplementedError + + +""" +================================================ +Visual Backbone Networks +================================================ +""" + + +class ConvBase(Module): + """ + Base class for ConvNets. + """ + + def __init__(self): + super(ConvBase, self).__init__() + + # dirty hack - re-implement to pass the buck onto subclasses from ABC parent + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + raise NotImplementedError + + def forward(self, inputs): + x = self.nets(inputs) + if list(self.output_shape(list(inputs.shape)[1:])) != list( + x.shape)[1:]: + raise ValueError('Size mismatch: expect size %s, but got size %s' % + (str(self.output_shape(list( + inputs.shape)[1:])), str(list(x.shape)[1:]))) + return x + + +class SpatialSoftmax(ConvBase): + """ + Spatial Softmax Layer. + + Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al. + https://rll.berkeley.edu/dsae/dsae.pdf + """ + + def __init__( + self, + input_shape, + num_kp=32, + temperature=1., + learnable_temperature=False, + output_variance=False, + noise_std=0.0, + ): + """ + Args: + input_shape (list): shape of the input feature (C, H, W) + num_kp (int): number of keypoints (None for not using spatialsoftmax) + temperature (float): temperature term for the softmax. + learnable_temperature (bool): whether to learn the temperature + output_variance (bool): treat attention as a distribution, and compute second-order statistics to return + noise_std (float): add random spatial noise to the predicted keypoints + """ + super(SpatialSoftmax, self).__init__() + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape # (C, H, W) + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._num_kp = num_kp + else: + self.nets = None + self._num_kp = self._in_c + self.learnable_temperature = learnable_temperature + self.output_variance = output_variance + self.noise_std = noise_std + + if self.learnable_temperature: + # temperature will be learned + temperature = torch.nn.Parameter(torch.ones(1) * temperature, + requires_grad=True) + self.register_parameter('temperature', temperature) + else: + # temperature held constant after initialization + temperature = torch.nn.Parameter(torch.ones(1) * temperature, + requires_grad=False) + self.register_buffer('temperature', temperature) + + pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w), + np.linspace(-1., 1., self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h * + self._in_w)).float() + pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h * + self._in_w)).float() + self.register_buffer('pos_x', pos_x) + self.register_buffer('pos_y', pos_y) + + self.kps = None + + def __repr__(self): + """Pretty print network.""" + header = format(str(self.__class__.__name__)) + return header + '(num_kp={}, temperature={}, noise={})'.format( + self._num_kp, self.temperature.item(), self.noise_std) + + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. 
+ + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + assert (len(input_shape) == 3) + assert (input_shape[0] == self._in_c) + return [self._num_kp, 2] + + def forward(self, feature): + """ + Forward pass through spatial softmax layer. For each keypoint, a 2D spatial + probability distribution is created using a softmax, where the support is the + pixel locations. This distribution is used to compute the expected value of + the pixel location, which becomes a keypoint of dimension 2. K such keypoints + are created. + + Returns: + out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly + keypoint variance of shape [B, K, 2, 2] corresponding to the covariance + under the 2D spatial softmax distribution + """ + assert (feature.shape[1] == self._in_c) + assert (feature.shape[2] == self._in_h) + assert (feature.shape[3] == self._in_w) + if self.nets is not None: + feature = self.nets(feature) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + feature = feature.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(feature / self.temperature, dim=-1) + # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions + expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True) + expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True) + # stack to [B * K, 2] + expected_xy = torch.cat([expected_x, expected_y], 1) + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._num_kp, 2) + + if self.training: + noise = torch.randn_like(feature_keypoints) * self.noise_std + feature_keypoints += noise + + if self.output_variance: + # treat attention as a distribution, and compute second-order statistics to return + expected_xx = torch.sum(self.pos_x * self.pos_x * attention, + dim=1, + keepdim=True) + expected_yy = torch.sum(self.pos_y * self.pos_y * attention, + dim=1, + keepdim=True) + expected_xy = torch.sum(self.pos_x * self.pos_y * attention, + dim=1, + keepdim=True) + var_x = expected_xx - expected_x * expected_x + var_y = expected_yy - expected_y * expected_y + var_xy = expected_xy - expected_x * expected_y + # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix + feature_covar = torch.cat([var_x, var_xy, var_xy, var_y], + 1).reshape(-1, self._num_kp, 2, 2) + feature_keypoints = (feature_keypoints, feature_covar) + + if isinstance(feature_keypoints, tuple): + self.kps = (feature_keypoints[0].detach(), + feature_keypoints[1].detach()) + else: + self.kps = feature_keypoints.detach() + return feature_keypoints diff --git a/src/unifolm_wma/models/diffusion_head/common/__init__.py b/src/unifolm_wma/models/diffusion_head/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py b/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py new file mode 100644 index 0000000..6f4f60a --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py @@ -0,0 +1,83 @@ +from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer, + TYPE_TO_SCHEDULER_FUNCTION) + + +def get_scheduler(name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: 
Optional[int] = None,
+                  num_training_steps: Optional[int] = None,
+                  **kwargs):
+    """
+    Adds `**kwargs` support compared to diffusers' original implementation.
+
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer, **kwargs)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(
+            f"{name} requires `num_warmup_steps`, please provide that argument."
+        )
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer,
+                             num_warmup_steps=num_warmup_steps,
+                             **kwargs)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(
+            f"{name} requires `num_training_steps`, please provide that argument."
+        )
+
+    return schedule_func(optimizer,
+                         num_warmup_steps=num_warmup_steps,
+                         num_training_steps=num_training_steps,
+                         **kwargs)
+
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class SelectiveLRScheduler(_LRScheduler):
+    """Steps `base_scheduler`, but applies the resulting learning rates only
+    to the parameter groups listed in `group_indices`; every other group is
+    reset to its entry in `default_lr`."""
+
+    def __init__(self,
+                 optimizer,
+                 base_scheduler,
+                 group_indices,
+                 default_lr=[1e-5, 1e-4],
+                 last_epoch=-1):
+        self.base_scheduler = base_scheduler
+        self.group_indices = group_indices  # Indices of parameter groups to update
+        self.default_lr = default_lr
+        super().__init__(optimizer, last_epoch)
+
+    def step(self, epoch=None):
+        self.base_scheduler.step()
+        base_lrs = self.base_scheduler.get_last_lr()
+
+        for idx, group in enumerate(self.optimizer.param_groups):
+            if idx in self.group_indices:
+                group['lr'] = base_lrs[idx]
+            else:
+                # Reset the learning rate to its initial value
+                group['lr'] = self.default_lr[idx]
diff --git a/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py b/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py
new file mode 100644
index 0000000..e33efe2
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py
@@ -0,0 +1,16 @@
+import torch.nn as nn
+
+
+class ModuleAttrMixin(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self._dummy_variable = nn.Parameter()
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        return next(iter(self.parameters())).dtype
diff --git a/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py b/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
new file mode 100644
index 0000000..a8913e7
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
@@ -0,0 +1,91 @@
+import collections
+import torch
+import torch.nn as nn
+from typing import Dict, Callable, List
+
+
+def dict_apply(
+        x: Dict[str, torch.Tensor],
+        func: Callable[[torch.Tensor],
torch.Tensor]) -> Dict[str, torch.Tensor]: + result = dict() + for key, value in x.items(): + if isinstance(value, dict): + result[key] = dict_apply(value, func) + else: + result[key] = func(value) + return result + + +def pad_remaining_dims(x, target): + assert x.shape == target.shape[:len(x.shape)] + return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape))) + + +def dict_apply_split( + x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor], + Dict[str, torch.Tensor]] +) -> Dict[str, torch.Tensor]: + results = collections.defaultdict(dict) + for key, value in x.items(): + result = split_func(value) + for k, v in result.items(): + results[k][key] = v + return results + + +def dict_apply_reduce( + x: List[Dict[str, + torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]], + torch.Tensor] +) -> Dict[str, torch.Tensor]: + result = dict() + for key in x[0].keys(): + result[key] = reduce_func([x_[key] for x_ in x]) + return result + + +def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module], + bool], + func: Callable[[nn.Module], nn.Module]) -> nn.Module: + """ + predicate: Return true if the module is to be replaced. + func: Return new module to use. + """ + if predicate(root_module): + return func(root_module) + + bn_list = [ + k.split('.') + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] + for *parent, k in bn_list: + parent_module = root_module + if len(parent) > 0: + parent_module = root_module.get_submodule('.'.join(parent)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + bn_list = [ + k.split('.') + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] + assert len(bn_list) == 0 + return root_module + + +def optimizer_to(optimizer, device): + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device=device) + return optimizer diff --git a/src/unifolm_wma/models/diffusion_head/common/tensor_util.py b/src/unifolm_wma/models/diffusion_head/common/tensor_util.py new file mode 100644 index 0000000..98e962a --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/common/tensor_util.py @@ -0,0 +1,960 @@ +""" +A collection of utilities for working with nested tensor structures consisting +of numpy arrays and torch tensors. +""" +import collections +import numpy as np +import torch + + +def recursive_dict_list_tuple_apply(x, type_func_dict): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. 
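+
+    Example (an illustrative sketch of the type dispatch)::
+
+        x = {"a": torch.zeros(2), "b": [np.ones(3)]}
+        y = recursive_dict_list_tuple_apply(
+            x, {
+                torch.Tensor: lambda t: t + 1,
+                np.ndarray: lambda a: a * 2,
+            })
+        # y["a"] is a tensor of ones; y["b"][0] is an array of twos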
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    assert (list not in type_func_dict)
+    assert (tuple not in type_func_dict)
+    assert (dict not in type_func_dict)
+
+    if isinstance(x, (dict, collections.OrderedDict)):
+        new_x = collections.OrderedDict() if isinstance(
+            x, collections.OrderedDict) else dict()
+        for k, v in x.items():
+            new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
+        return new_x
+    elif isinstance(x, (list, tuple)):
+        ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
+        if isinstance(x, tuple):
+            ret = tuple(ret)
+        return ret
+    else:
+        for t, f in type_func_dict.items():
+            if isinstance(x, t):
+                return f(x)
+        else:
+            # for/else: raise only if no handler matched
+            raise NotImplementedError('Cannot handle data type %s' %
+                                      str(type(x)))
+
+
+def map_tensor(x, func):
+    """
+    Apply function @func to torch.Tensor objects in a nested dictionary or
+    list or tuple.
+
+    Args:
+        x (dict or list or tuple): a possibly nested dictionary or list or tuple
+        func (function): function to apply to each tensor
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    return recursive_dict_list_tuple_apply(x, {
+        torch.Tensor: func,
+        type(None): lambda x: x,
+    })
+
+
+def map_ndarray(x, func):
+    """
+    Apply function @func to np.ndarray objects in a nested dictionary or
+    list or tuple.
+
+    Args:
+        x (dict or list or tuple): a possibly nested dictionary or list or tuple
+        func (function): function to apply to each array
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    return recursive_dict_list_tuple_apply(x, {
+        np.ndarray: func,
+        type(None): lambda x: x,
+    })
+
+
+def map_tensor_ndarray(x, tensor_func, ndarray_func):
+    """
+    Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
+    np.ndarray objects in a nested dictionary or list or tuple.
+
+    Args:
+        x (dict or list or tuple): a possibly nested dictionary or list or tuple
+        tensor_func (function): function to apply to each tensor
+        ndarray_func (function): function to apply to each array
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    return recursive_dict_list_tuple_apply(
+        x, {
+            torch.Tensor: tensor_func,
+            np.ndarray: ndarray_func,
+            type(None): lambda x: x,
+        })
+
+
+def clone(x):
+    """
+    Clones all torch tensors and numpy arrays in nested dictionary or list
+    or tuple and returns a new nested structure.
+
+    Args:
+        x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    return recursive_dict_list_tuple_apply(
+        x, {
+            torch.Tensor: lambda x: x.clone(),
+            np.ndarray: lambda x: x.copy(),
+            type(None): lambda x: x,
+        })
+
+
+def detach(x):
+    """
+    Detaches all torch tensors in nested dictionary or list
+    or tuple and returns a new nested structure.
+
+    Args:
+        x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+    Returns:
+        y (dict or list or tuple): new nested dict-list-tuple
+    """
+    return recursive_dict_list_tuple_apply(
+        x,
+        {
+            torch.Tensor: lambda x: x.detach(),
+            # pass None through, consistent with the other helpers here
+            type(None): lambda x: x,
+        })
+
+
+def to_batch(x):
+    """
+    Introduces a leading batch dimension of 1 for all torch tensors and numpy
+    arrays in nested dictionary or list or tuple and returns a new nested structure.
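+
+    Example (illustrative)::
+
+        batch = to_batch({"rgb": torch.zeros(3, 84, 84)})
+        # batch["rgb"].shape == (1, 3, 84, 84)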
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x[None, ...], + np.ndarray: lambda x: x[None, ...], + type(None): lambda x: x, + }) + + +def to_sequence(x): + """ + Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x[:, None, ...], + np.ndarray: lambda x: x[:, None, ...], + type(None): lambda x: x, + }) + + +def index_at_time(x, ind): + """ + Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in + nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ind (int): index + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x[:, ind, ...], + np.ndarray: lambda x: x[:, ind, ...], + type(None): lambda x: x, + }) + + +def unsqueeze(x, dim): + """ + Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x.unsqueeze(dim=dim), + np.ndarray: lambda x: np.expand_dims(x, axis=dim), + type(None): lambda x: x, + }) + + +def contiguous(x): + """ + Makes all torch tensors and numpy arrays contiguous in nested dictionary or + list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x.contiguous(), + np.ndarray: lambda x: np.ascontiguousarray(x), + type(None): lambda x: x, + }) + + +def to_device(x, device): + """ + Sends all torch tensors in nested dictionary or list or tuple to device + @device, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x, d=device: x.to(d), + type(None): lambda x: x, + }) + + +def to_tensor(x): + """ + Converts all numpy arrays in nested dictionary or list or tuple to + torch tensors (and leaves existing torch Tensors as-is), and returns + a new nested structure. 
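+
+    Example (illustrative)::
+
+        out = to_tensor({"state": np.zeros(7), "mask": None})
+        # out["state"] is a torch.Tensor; None entries pass through unchanged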
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x, + np.ndarray: lambda x: torch.from_numpy(x), + type(None): lambda x: x, + }) + + +def to_numpy(x): + """ + Converts all torch tensors in nested dictionary or list or tuple to + numpy (and leaves existing numpy arrays as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy() + else: + return tensor.detach().numpy() + + return recursive_dict_list_tuple_apply(x, { + torch.Tensor: f, + np.ndarray: lambda x: x, + type(None): lambda x: x, + }) + + +def to_list(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to a list, and returns a new nested structure. Useful for + json encoding. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy().tolist() + else: + return tensor.detach().numpy().tolist() + + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: f, + np.ndarray: lambda x: x.tolist(), + type(None): lambda x: x, + }) + + +def to_float(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to float type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x.float(), + np.ndarray: lambda x: x.astype(np.float32), + type(None): lambda x: x, + }) + + +def to_uint8(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to uint8 type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x.byte(), + np.ndarray: lambda x: x.astype(np.uint8), + type(None): lambda x: x, + }) + + +def to_torch(x, device): + """ + Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to + torch tensors on device @device and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return to_device(to_float(to_tensor(x)), device) + + +def to_one_hot_single(tensor, num_class): + """ + Convert tensor to one-hot representation, assuming a certain number of total class labels. 
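+
+    Example (illustrative)::
+
+        onehot = to_one_hot_single(torch.tensor([0, 2]), num_class=3)
+        # onehot == [[1., 0., 0.], [0., 0., 1.]] with shape (2, 3)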
+ + Args: + tensor (torch.Tensor): tensor containing integer labels + num_class (int): number of classes + + Returns: + x (torch.Tensor): tensor containing one-hot representation of labels + """ + x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device) + x.scatter_(-1, tensor.unsqueeze(-1), 1) + return x + + +def to_one_hot(tensor, num_class): + """ + Convert all tensors in nested dictionary or list or tuple to one-hot representation, + assuming a certain number of total class labels. + + Args: + tensor (dict or list or tuple): a possibly nested dictionary or list or tuple + num_class (int): number of classes + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(tensor, + func=lambda x, nc=num_class: to_one_hot_single(x, nc)) + + +def flatten_single(x, begin_axis=1): + """ + Flatten a tensor in all dimensions from @begin_axis onwards. + + Args: + x (torch.Tensor): tensor to flatten + begin_axis (int): which axis to flatten from + + Returns: + y (torch.Tensor): flattened tensor + """ + fixed_size = x.size()[:begin_axis] + _s = list(fixed_size) + [-1] + return x.reshape(*_s) + + +def flatten(x, begin_axis=1): + """ + Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): which axis to flatten from + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply(x, { + torch.Tensor: + lambda x, b=begin_axis: flatten_single(x, begin_axis=b), + }) + + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert (begin_axis <= end_axis) + assert (begin_axis >= 0) + assert (end_axis < len(x.shape)) + assert (isinstance(target_dims, (tuple, list))) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i > end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def reshape_dimensions(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions for all tensors in nested dictionary or list or tuple + to a target dimension. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: + lambda x, b=begin_axis, e=end_axis, t=target_dims: + reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t), + np.ndarray: + lambda x, b=begin_axis, e=end_axis, t=target_dims: + reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t), + type(None): + lambda x: x, + }) + + +def join_dimensions(x, begin_axis, end_axis): + """ + Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for + all tensors in nested dictionary or list or tuple. 
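+
+    Example (illustrative)::
+
+        y = join_dimensions({"obs": torch.zeros(4, 8, 3)},
+                            begin_axis=0, end_axis=1)
+        # y["obs"].shape == (32, 3)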
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: + lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1]), + np.ndarray: + lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1]), + type(None): + lambda x: x, + }) + + +def expand_at_single(x, size, dim): + """ + Expand a tensor at a single dimension @dim by @size + + Args: + x (torch.Tensor): input tensor + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (torch.Tensor): expanded tensor + """ + assert dim < x.ndimension() + assert x.shape[dim] == 1 + expand_dims = [-1] * x.ndimension() + expand_dims[dim] = size + return x.expand(*expand_dims) + + +def expand_at(x, size, dim): + """ + Expand all tensors in nested dictionary or list or tuple at a single + dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) + + +def unsqueeze_expand_at(x, size, dim): + """ + Unsqueeze and expand a tensor at a dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to unsqueeze and expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze(x, dim) + return expand_at(x, size, dim) + + +def repeat_by_expand_at(x, repeats, dim): + """ + Repeat a dimension by combining expand and reshape operations. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + repeats (int): number of times to repeat the target dimension + dim (int): dimension to repeat on + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze_expand_at(x, repeats, dim + 1) + return join_dimensions(x, dim, dim + 1) + + +def named_reduce_single(x, reduction, dim): + """ + Reduce tensor at a dimension by named reduction functions. + + Args: + x (torch.Tensor): tensor to be reduced + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (torch.Tensor): reduced tensor + """ + assert x.ndimension() > dim + assert reduction in ["sum", "max", "mean", "flatten"] + if reduction == "flatten": + x = flatten(x, begin_axis=dim) + elif reduction == "max": + x = torch.max(x, dim=dim)[0] # [B, D] + elif reduction == "sum": + x = torch.sum(x, dim=dim) + else: + x = torch.mean(x, dim=dim) + return x + + +def named_reduce(x, reduction, dim): + """ + Reduces all tensors in nested dictionary or list or tuple at a dimension + using a named reduction function. 
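+
+    Example (illustrative)::
+
+        y = named_reduce({"feat": torch.ones(2, 5, 8)}, reduction="mean", dim=1)
+        # y["feat"].shape == (2, 8)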
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) + + +def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): + """ + This function indexes out a target dimension of a tensor in a structured way, + by allowing a different value to be selected for each member of a flat index + tensor (@indices) corresponding to a source dimension. This can be interpreted + as moving along the source dimension, using the corresponding index value + in @indices to select values for all other dimensions outside of the + source and target dimensions. A common use case is to gather values + in target dimension 1 for each batch member (target dimension 0). + + Args: + x (torch.Tensor): tensor to gather values for + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out + """ + assert len(indices.shape) == 1 + assert x.shape[source_dim] == indices.shape[0] + + # unsqueeze in all dimensions except the source dimension + new_shape = [1] * x.ndimension() + new_shape[source_dim] = -1 + indices = indices.reshape(*new_shape) + + # repeat in all dimensions - but preserve shape of source dimension, + # and make sure target_dimension has singleton dimension + expand_shape = list(x.shape) + expand_shape[source_dim] = -1 + expand_shape[target_dim] = 1 + indices = indices.expand(*expand_shape) + + out = x.gather(dim=target_dim, index=indices) + return out.squeeze(target_dim) + + +def gather_along_dim_with_dim(x, target_dim, source_dim, indices): + """ + Apply @gather_along_dim_with_dim_single to all tensors in a nested + dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, + lambda y, t=target_dim, s=source_dim, i=indices: + gather_along_dim_with_dim_single(y, t, s, i)) + + +def gather_sequence_single(seq, indices): + """ + Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in + the batch given an index for each sequence. + + Args: + seq (torch.Tensor): tensor with leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Return: + y (torch.Tensor): indexed tensor of shape [B, ....] + """ + return gather_along_dim_with_dim_single(seq, + target_dim=1, + source_dim=0, + indices=indices) + + +def gather_sequence(seq, indices): + """ + Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch + for tensors with leading dimensions [B, T, ...]. + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] 
+ indices (torch.Tensor): tensor indices of shape [B] + + Returns: + y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] + """ + return gather_along_dim_with_dim(seq, + target_dim=1, + source_dim=0, + indices=indices) + + +def pad_sequence_single(seq, + padding, + batched=False, + pad_same=True, + pad_values=None): + """ + Pad input tensor or array @seq in the time dimension (dimension 1). + + Args: + seq (np.ndarray or torch.Tensor): sequence to be padded + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (np.ndarray or torch.Tensor) + """ + assert isinstance(seq, (np.ndarray, torch.Tensor)) + assert pad_same or pad_values is not None + if pad_values is not None: + assert isinstance(pad_values, float) + repeat_func = np.repeat if isinstance( + seq, np.ndarray) else torch.repeat_interleave + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + ones_like_func = np.ones_like if isinstance( + seq, np.ndarray) else torch.ones_like + seq_dim = 1 if batched else 0 + + begin_pad = [] + end_pad = [] + + if padding[0] > 0: + pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values + begin_pad.append(repeat_func(pad, padding[0], seq_dim)) + if padding[1] > 0: + pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values + end_pad.append(repeat_func(pad, padding[1], seq_dim)) + + return concat_func(begin_pad + [seq] + end_pad, seq_dim) + + +def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): + """ + Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (dict or list or tuple) + """ + return recursive_dict_list_tuple_apply( + seq, { + torch.Tensor: + lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: + pad_sequence_single(x, p, b, ps, pv), + np.ndarray: + lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: + pad_sequence_single(x, p, b, ps, pv), + type(None): + lambda x: x, + }) + + +def assert_size_at_dim_single(x, size, dim, msg): + """ + Ensure that array or tensor @x has size @size in dim @dim. + + Args: + x (np.ndarray or torch.Tensor): input array or tensor + size (int): size that tensors should have at @dim + dim (int): dimension to check + msg (str): text to display if assertion fails + """ + assert x.shape[dim] == size, msg + + +def assert_size_at_dim(x, size, dim, msg): + """ + Ensure that arrays and tensors in nested dictionary or list or tuple have + size @size in dim @dim. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size that tensors should have at @dim + dim (int): dimension to check + """ + map_tensor( + x, + lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) + + +def get_shape(x): + """ + Get all shapes of arrays and tensors in nested dictionary or list or tuple. 
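+
+    Example (illustrative)::
+
+        shapes = get_shape({"rgb": torch.zeros(2, 3, 84, 84)})
+        # shapes == {"rgb": torch.Size([2, 3, 84, 84])}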
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple that contains each array or + tensor's shape + """ + return recursive_dict_list_tuple_apply( + x, { + torch.Tensor: lambda x: x.shape, + np.ndarray: lambda x: x.shape, + type(None): lambda x: x, + }) + + +def list_of_flat_dict_to_dict_of_list(list_of_dict): + """ + Helper function to go from a list of flat dictionaries to a dictionary of lists. + By "flat" we mean that none of the values are dictionaries, but are numpy arrays, + floats, etc. + + Args: + list_of_dict (list): list of flat dictionaries + + Returns: + dict_of_list (dict): dictionary of lists + """ + assert isinstance(list_of_dict, list) + dic = collections.OrderedDict() + for i in range(len(list_of_dict)): + for k in list_of_dict[i]: + if k not in dic: + dic[k] = [] + dic[k].append(list_of_dict[i][k]) + return dic + + +def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''): + """ + Flatten a nested dict or list to a list. + + For example, given a dict + { + a: 1 + b: { + c: 2 + } + c: 3 + } + + the function would return [(a, 1), (b_c, 2), (c, 3)] + + Args: + d (dict, list): a nested dict or list to be flattened + parent_key (str): recursion helper + sep (str): separator for nesting keys + item_key (str): recursion helper + Returns: + list: a list of (key, value) tuples + """ + items = [] + if isinstance(d, (tuple, list)): + new_key = parent_key + sep + item_key if len( + parent_key) > 0 else item_key + for i, v in enumerate(d): + items.extend( + flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) + return items + elif isinstance(d, dict): + new_key = parent_key + sep + item_key if len( + parent_key) > 0 else item_key + for k, v in d.items(): + assert isinstance(k, str) + items.extend( + flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) + return items + else: + new_key = parent_key + sep + item_key if len( + parent_key) > 0 else item_key + return [(new_key, d)] + + +def time_distributed(inputs, + op, + activation=None, + inputs_as_kwargs=False, + inputs_as_args=False, + **kwargs): + """ + Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the + batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. + Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping + outputs to [B, T, ...]. + + Args: + inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + op: a layer op that accepts inputs + activation: activation to apply at the output + inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op + inputs_as_args (bool) whether to feed input as a args list to the op + kwargs (dict): other kwargs to supply to the op + + Returns: + outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. 
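+
+    Example (illustrative; ``encoder`` is a hypothetical ``nn.Module``)::
+
+        imgs = {"rgb": torch.zeros(4, 10, 3, 84, 84)}  # [B, T, C, H, W]
+        feats = time_distributed(imgs, encoder)
+        # encoder is applied to inputs of shape [B * T, C, H, W]; tensors in
+        # `feats` come back with leading shape [4, 10, ...]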
+    """
+    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
+    inputs = join_dimensions(inputs, 0, 1)
+    if inputs_as_kwargs:
+        outputs = op(**inputs, **kwargs)
+    elif inputs_as_args:
+        outputs = op(*inputs, **kwargs)
+    else:
+        outputs = op(inputs, **kwargs)
+
+    if activation is not None:
+        outputs = map_tensor(outputs, activation)
+    outputs = reshape_dimensions(outputs,
+                                 begin_axis=0,
+                                 end_axis=0,
+                                 target_dims=(batch_size, seq_len))
+    return outputs
diff --git a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
new file mode 100644
index 0000000..12378a1
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
@@ -0,0 +1,701 @@
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # used by GEGLU
+import einops
+import xformers.ops  # memory-efficient attention in CrossAttention
+
+from einops import rearrange, repeat
+from typing import Union
+
+from unifolm_wma.models.diffusion_head.conv1d_components import (
+    Downsample1d, Upsample1d, Conv1dBlock)
+from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
+from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
+
+from unifolm_wma.utils.basics import zero_module
+from unifolm_wma.utils.common import (
+    checkpoint,
+    exists,
+    default,
+)
+from unifolm_wma.utils.utils import instantiate_from_config
+
+logger = logging.getLogger(__name__)
+
+
+class GEGLU(nn.Module):
+
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(nn.Linear(
+            dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
+                                 nn.Linear(inner_dim, dim_out))
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class CrossAttention(nn.Module):
+
+    def __init__(self,
+                 query_dim,
+                 context_dim=None,
+                 heads=8,
+                 dim_head=64,
+                 dropout=0.,
+                 relative_position=False):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
+                                    nn.Dropout(dropout))
+
+    def forward(self, x, context=None, mask=None):
+        # `mask` is accepted for interface compatibility with the callers in
+        # BasicTransformerBlock, but the memory-efficient path does not use it.
+        return self.efficient_forward(x, context)
+
+    def efficient_forward(self, x, context=None):
+        spatial_self_attn = (context is None)
+
+        q = self.to_q(x)
+        if spatial_self_attn:
+            context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(b, t.shape[
+                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
+            (q, k, v),
+        )
+        # memory-efficient attention over the flattened token sequences
+        out = xformers.ops.memory_efficient_attention(q,
+                                                      k,
+                                                      v,
+                                                      attn_bias=None,
+                                                      op=None)
+        out = (out.unsqueeze(0).reshape(
+            b, self.heads, out.shape[1],
+            self.dim_head).permute(0, 2, 1,
+                                   3).reshape(b, out.shape[1],
+                                              self.heads * self.dim_head))
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 n_heads,
+                 d_head,
+                 dropout=0.,
+                 context_dim=None,
+                 gated_ff=True,
+                 checkpoint=True,
+                 disable_self_attn=False,
+                 attention_cls=None):
+        super().__init__()
+        attn_cls = CrossAttention if attention_cls is None else attention_cls
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(
+            query_dim=dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+            context_dim=context_dim if self.disable_self_attn else None)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim,
+                              context_dim=context_dim,
+                              heads=n_heads,
+                              dim_head=d_head,
+                              dropout=dropout)
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None, **kwargs):
+        # Checkpointing does not support non-tensor (e.g. None or scalar)
+        # arguments, so only tensors are packed into the checkpointed call.
+        # Keep the trailing comma, i.e. `(x, )` rather than `(x)`, so that
+        # unpacking `*input_tuple` yields a single argument.
+        input_tuple = (x, )
+        if context is not None:
+            input_tuple = (x, context)
+        return checkpoint(self._forward, input_tuple, self.parameters(),
+                          self.checkpoint)
+
+    def _forward(self, x, context=None, mask=None):
+        x = self.attn1(self.norm1(x),
+                       context=context if self.disable_self_attn else None,
+                       mask=mask) + x
+        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class ActionLatentImageCrossAttention(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 in_dim,
+                 n_heads,
+                 d_head,
+                 depth=1,
+                 dropout=0.,
+                 context_dim=None,
+                 use_checkpoint=True,
+                 disable_self_attn=False,
+                 use_linear=True):
+        """
+        Args:
+            in_channels: number of action feature channels (used for the
+                GroupNorm over the action tokens).
+            in_dim: per-token feature dimension of the action input, projected
+                into the attention inner dimension.
+        """
+        super().__init__()
+        self.in_channels = in_channels
+        self.in_dim = in_dim
+        inner_dim = n_heads * d_head
+        self.norm = torch.nn.GroupNorm(num_groups=8,
+                                       num_channels=in_channels,
+                                       eps=1e-6,
+                                       affine=True)
+
+        self.proj_in_action = nn.Linear(in_dim, inner_dim)
+        self.proj_in_cond = nn.Linear(context_dim, inner_dim)
+        self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
+        self.use_linear = use_linear
+
+        attention_cls = None
+        self.transformer_blocks = nn.ModuleList([
+            BasicTransformerBlock(inner_dim,
+                                  n_heads,
+                                  d_head,
+                                  dropout=dropout,
+                                  context_dim=context_dim,
+                                  disable_self_attn=disable_self_attn,
+                                  checkpoint=use_checkpoint,
+                                  attention_cls=attention_cls)
+            for d in range(depth)
+        ])
+
+    def forward(self, x, context=None, **kwargs):
+        ba, ca, da = x.shape  # action tokens: [batch, channels, dim]
+        b, t, c, h, w = context.shape
+        context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
+
+        x_in = x
+        x = self.norm(x)
+        if self.use_linear:
+            x = self.proj_in_action(x)
+            context = self.proj_in_cond(context)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context, **kwargs)
+        if self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
+
+
+class ConditionalResidualBlock1D(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 cond_dim,
+                 kernel_size=3,
+                 n_groups=8,
+                 cond_predict_scale=True,
+                 use_linear_act_proj=False):
+        super().__init__()
+
+        self.blocks = nn.ModuleList([
+            Conv1dBlock(in_channels,
+                        out_channels,
+                        kernel_size,
+                        n_groups=n_groups),
+            Conv1dBlock(out_channels,
+                        out_channels,
+                        kernel_size,
+                        n_groups=n_groups),
+        ])
+
+        self.cond_predict_scale = cond_predict_scale
+        self.use_linear_act_proj = use_linear_act_proj
+        self.out_channels = out_channels
+        # FiLM modulation https://arxiv.org/abs/1709.07871
+        # predicts per-channel scale and bias
+        cond_channels = out_channels
+        if cond_predict_scale:
+            # scale-and-bias prediction needs two values per output channel
+            # (the original `cond_predict_scale and use_linear_act_proj` test
+            # would under-size the encoder when only the first flag is set)
+            cond_channels = out_channels * 2
+        self.cond_encoder = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(cond_dim, cond_channels),
+        )
+        # make sure dimensions are compatible
+        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
+            if in_channels != out_channels else nn.Identity()
+
+    def forward(self, x, cond=None):
+        '''
+        x : [ batch_size x in_channels x horizon ]
+        cond : [ batch_size x T x cond_dim ]
+
+        returns:
+        out : [ batch_size x out_channels x horizon ]
+        '''
+        B, T, _ = cond.shape
+
+        out = self.blocks[0](x)
+        if self.cond_predict_scale:
+            embed = self.cond_encoder(cond)
+            if self.use_linear_act_proj:
+                embed = embed.reshape(B * T, -1)
+                embed = embed.reshape(-1, 2, self.out_channels, 1)
+            else:
+                embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
+            scale = embed[:, 0, ...]
+            bias = embed[:, 1, ...]
+            out = scale * out + bias
+        else:
+            # bias-only FiLM conditioning (restored from the commented-out
+            # branch; assumes cond_channels == out_channels in this mode)
+            embed = self.cond_encoder(cond).reshape(-1, self.out_channels, 1)
+            out = out + embed
+        out = self.blocks[1](out)
+        out = out + self.residual_conv(x)
+        return out
+
+
+class ConditionalUnet1D(nn.Module):
+
+    def __init__(self,
+                 input_dim,
+                 n_obs_steps=1,
+                 local_cond_dim=None,
+                 global_cond_dim=None,
+                 diffusion_step_embed_dim=256,
+                 down_dims=(256, 512, 1024),
+                 kernel_size=3,
+                 n_groups=8,
+                 cond_predict_scale=False,
+                 horizon=16,
+                 num_head_channels=64,
+                 use_linear_attn=True,
+                 use_linear_act_proj=True,
+                 act_proj_dim=32,
+                 cond_cross_attention=False,
+                 context_dims=None,
+                 image_size=None,
+                 imagen_cond_gradient=False,
+                 last_frame_only=False,
+                 use_imagen_mid_only=False,
+                 use_z_only=False,
+                 spatial_num_kp=32,
+                 obs_encoder_config=None):
+        super().__init__()
+
+        self.n_obs_steps = n_obs_steps
+        self.obs_encoder = instantiate_from_config(obs_encoder_config)
+
+        all_dims = [input_dim] + list(down_dims)
+        start_dim = down_dims[0]
+
+        dsed = diffusion_step_embed_dim
+        diffusion_step_encoder = nn.Sequential(
+            SinusoidalPosEmb(dsed),
+            nn.Linear(dsed, dsed * 4),
+            nn.Mish(),
+            nn.Linear(dsed * 4, dsed),
+        )
+        cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
+        in_out = list(zip(all_dims[:-1], all_dims[1:]))
+        local_cond_encoder = None
+        down_modules = nn.ModuleList([])
+
+        dim_a_list = []
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (len(in_out) - 1)
+            if ind == 0:
+                dim_a = horizon
+            else:
+                dim_a = horizon // 2 * ind
+            dim_a_list.append(dim_a)
+
+            # attention head configuration
+            num_heads = dim_out // num_head_channels
+            dim_head = num_head_channels
+            if use_linear_act_proj:
+                if use_imagen_mid_only:
+                    cur_cond_dim = cond_dim + 2 * context_dims[-1]
+                elif use_z_only:
+                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
+                else:
+                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
+            else:
+                cur_cond_dim = cond_dim + horizon * context_dims[ind]
+
+            down_modules.append(
+                nn.ModuleList([
+                    ConditionalResidualBlock1D(
+                        dim_in,
+                        dim_out,
+                        cond_dim=cur_cond_dim,
+                        kernel_size=kernel_size,
+                        n_groups=n_groups,
+                        cond_predict_scale=cond_predict_scale,
+                        use_linear_act_proj=use_linear_act_proj),
+                    ConditionalResidualBlock1D(
+                        dim_out,
+                        dim_out,
+                        cond_dim=cur_cond_dim,
+                        kernel_size=kernel_size,
+                        n_groups=n_groups,
+                        cond_predict_scale=cond_predict_scale,
+                        use_linear_act_proj=use_linear_act_proj),
+                    ActionLatentImageCrossAttention(
+                        dim_out,
+                        dim_a,
+                        num_heads,
+                        dim_head,
+                        context_dim=context_dims[ind],
+                        use_linear=use_linear_attn)
+                    if cond_cross_attention else nn.Identity(),
+                    Downsample1d(dim_out) if not is_last else nn.Identity()
+                ]))
+
+        mid_dim = 
all_dims[-1] + self.mid_modules = nn.ModuleList([ + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cur_cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + use_linear_act_proj=use_linear_act_proj), + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cur_cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + use_linear_act_proj=use_linear_act_proj), + ActionLatentImageCrossAttention(mid_dim, + dim_a_list[-1], + num_heads, + dim_head, + context_dim=context_dims[-1], + use_linear=use_linear_attn) + if cond_cross_attention else nn.Identity(), + ]) + + up_modules = nn.ModuleList([]) + context_dims = context_dims[::-1] + for ind, (dim_in, dim_out) in enumerate( + reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])): + is_last = ind >= (len(in_out) - 1) + if use_linear_act_proj: + if use_imagen_mid_only: + cur_cond_dim = cond_dim + 2 * context_dims[0] + elif use_z_only: + cur_cond_dim = cond_dim + 2 * spatial_num_kp + else: + cur_cond_dim = cond_dim + 2 * context_dims[ind] + else: + cur_cond_dim = cond_dim + horizon * context_dims[ind] + up_modules.append( + nn.ModuleList([ + ConditionalResidualBlock1D( + dim_out + dim_in, + dim_in, + cond_dim=cur_cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + use_linear_act_proj=use_linear_act_proj), + ConditionalResidualBlock1D( + dim_in, + dim_in, + cond_dim=cur_cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + use_linear_act_proj=use_linear_act_proj), + ActionLatentImageCrossAttention( + dim_in, + dim_a_list.pop(), + num_heads, + dim_head, + context_dim=context_dims[ind], + use_linear=use_linear_attn) + if cond_cross_attention else nn.Identity(), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), + nn.Conv1d(start_dim, input_dim, 1), + ) + + if use_z_only: + h, w = image_size + self.spatial_softmax_blocks = nn.ModuleList( + [SpatialSoftmax((4, h, w), spatial_num_kp)]) + else: + self.spatial_softmax_blocks = nn.ModuleList([]) + context_dims = context_dims[::-1] + for ind, context_dim in enumerate(context_dims): + h, w = image_size + if ind != 0: + h //= 2**ind + w //= 2**ind + net = SpatialSoftmax((context_dim, h, w), context_dim) + self.spatial_softmax_blocks.append(net) + self.spatial_softmax_blocks.append(net) + self.spatial_softmax_blocks += self.spatial_softmax_blocks[ + 0:4][::-1] + + self.diffusion_step_encoder = diffusion_step_encoder + self.local_cond_encoder = local_cond_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + self.cond_cross_attention = cond_cross_attention + self.use_linear_act_proj = use_linear_act_proj + + self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim), + nn.LayerNorm(act_proj_dim)) + self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim), + nn.LayerNorm(act_proj_dim)) + self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim), + nn.Linear(act_proj_dim, 1)) + self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim), + nn.Linear(act_proj_dim, horizon)) + logger.info("number of parameters: %e", + sum(p.numel() for p in self.parameters())) + + self.imagen_cond_gradient = imagen_cond_gradient + self.use_imagen_mid_only = use_imagen_mid_only + self.use_z_only = use_z_only + self.spatial_num_kp = spatial_num_kp + 
self.last_frame_only = last_frame_only
+        self.horizon = horizon
+
+    def forward(self,
+                sample: torch.Tensor,
+                timestep: Union[torch.Tensor, float, int],
+                imagen_cond=None,
+                cond=None,
+                **kwargs):
+        """
+        sample: (B, T, input_dim)
+        timestep: (B,) or int, diffusion step
+        imagen_cond: a list of hidden features from the video generation U-Net
+        cond: dict:
+            image: (B, 3, To, h, w)
+            agent_pos: (B, Ta, d)
+        output: (B, T, input_dim)
+        """
+
+        if not self.imagen_cond_gradient:
+            imagen_cond = [c.detach() for c in imagen_cond]
+
+        cond = {'image': cond[0], 'agent_pos': cond[1]}
+
+        # (B, 3, To, h, w) -> (B, To, 3, h, w)
+        cond['image'] = cond['image'].permute(0, 2, 1, 3, 4)
+        cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
+        cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
+
+        B, T, D = sample.shape
+        if self.use_linear_act_proj:
+            sample = self.proj_in_action(sample.unsqueeze(-1))
+            global_cond = self.obs_encoder(cond)
+            global_cond = rearrange(global_cond,
+                                    '(b t) d -> b 1 (t d)',
+                                    b=B,
+                                    t=self.n_obs_steps)
+            global_cond = repeat(global_cond,
+                                 'b c d -> b (repeat c) d',
+                                 repeat=T)
+        else:
+            sample = einops.rearrange(sample, 'b h t -> b t h')
+            sample = self.proj_in_horizon(sample)
+            # NOTE: the original code referenced an undefined
+            # `robo_state_cond` here; this branch is assumed to use the same
+            # encoded observation features as the projection branch above.
+            global_cond = self.obs_encoder(cond)
+            global_cond = rearrange(global_cond,
+                                    '(b t) d -> b 1 (t d)',
+                                    b=B,
+                                    t=self.n_obs_steps)
+            global_cond = repeat(global_cond,
+                                 'b c d -> b (repeat c) d',
+                                 repeat=2)
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps],
+                                     dtype=torch.long,
+                                     device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+        global_feature = self.diffusion_step_encoder(timesteps)
+        # NOTE: hard-coded split of the video U-Net features into
+        # down / mid / up groups
+        (imagen_cond_down, imagen_cond_mid,
+         imagen_cond_up) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]
+
+        x = sample if not self.use_linear_act_proj else sample.reshape(
+            B * T, D, -1)
+        h = []
+        for idx, modules in enumerate(self.down_modules):
+            if self.cond_cross_attention:
+                (resnet, resnet2, crossatten, downsample) = modules
+            else:
+                (resnet, resnet2, _, downsample) = modules
+
+            # Access the cond from the unet embeds from video unet
+            if self.use_imagen_mid_only:
+                imagen_cond = imagen_cond_mid
+            elif self.use_z_only:
+                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+            else:
+                imagen_cond = imagen_cond_down[idx]
+            if self.last_frame_only:
+                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+                imagen_cond = repeat(imagen_cond,
+                                     'b t c h w -> b (repeat t) c h w',
+                                     repeat=self.horizon)
+            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+            if self.use_imagen_mid_only:
+                imagen_cond = self.spatial_softmax_blocks[len(
+                    self.spatial_softmax_blocks) // 2](imagen_cond)
+            elif self.use_z_only:
+                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+            else:
+                imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
+            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+
+            if self.use_linear_act_proj:
+                imagen_cond = imagen_cond.reshape(B, T, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=T, dim=1)
+            else:
+                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+                imagen_cond = imagen_cond.reshape(B, 2, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=2, dim=1)
+            cur_global_feature = torch.cat(
+                [cur_global_feature, global_cond, imagen_cond], axis=-1)
+            x = resnet(x, cur_global_feature)
+            x = resnet2(x, cur_global_feature)
+            h.append(x)
+            x = downsample(x)
+
+        #>>> mid blocks
+        resnet, resnet2, _ = self.mid_modules
+        # Access the cond from the unet embeds from video unet
+        if self.use_z_only:
+            imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+        else:
+            imagen_cond = imagen_cond_mid
+        if self.last_frame_only:
+            imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+            imagen_cond = repeat(imagen_cond,
+                                 'b t c h w -> b (repeat t) c h w',
+                                 repeat=self.horizon)
+        imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+        idx += 1
+        if self.use_z_only:
+            imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+        else:
+            imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
+        imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+        if self.use_linear_act_proj:
+            imagen_cond = imagen_cond.reshape(B, T, -1)
+            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+                repeats=T, dim=1)
+        else:
+            imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+            imagen_cond = imagen_cond.reshape(B, 2, -1)
+            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+                repeats=2, dim=1)
+        cur_global_feature = torch.cat(
+            [cur_global_feature, global_cond, imagen_cond], axis=-1)
+        x = resnet(x, cur_global_feature)
+        x = resnet2(x, cur_global_feature)
+
+        #>>> up blocks
+        idx += 1
+        for jdx, modules in enumerate(self.up_modules):
+            if self.cond_cross_attention:
+                (resnet, resnet2, crossatten, upsample) = modules
+            else:
+                (resnet, resnet2, _, upsample) = modules
+
+            # Access the cond from the unet embeds from video unet
+            if self.use_imagen_mid_only:
+                imagen_cond = imagen_cond_mid
+            elif self.use_z_only:
+                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+            else:
+                imagen_cond = imagen_cond_up[jdx]
+            if self.last_frame_only:
+                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+                imagen_cond = repeat(imagen_cond,
+                                     'b t c h w -> b (repeat t) c h w',
+                                     repeat=self.horizon)
+            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+            if self.use_imagen_mid_only:
+                imagen_cond = self.spatial_softmax_blocks[len(
+                    self.spatial_softmax_blocks) // 2](imagen_cond)
+            elif self.use_z_only:
+                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+            else:
+                imagen_cond = self.spatial_softmax_blocks[jdx +
+                                                          idx](imagen_cond)
+            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+
+            if self.use_linear_act_proj:
+                imagen_cond = imagen_cond.reshape(B, T, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=T, dim=1)
+            else:
+                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+                imagen_cond = imagen_cond.reshape(B, 2, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=2, dim=1)
+
+            cur_global_feature = torch.cat(
+                [cur_global_feature, global_cond, imagen_cond], axis=-1)
+
+            x = torch.cat((x, h.pop()), dim=1)
+            x = resnet(x, cur_global_feature)
+            x = resnet2(x, cur_global_feature)
+            x = upsample(x)
+
+        x = self.final_conv(x)
+
+        if self.use_linear_act_proj:
+            x = x.reshape(B, T, D, -1)
+            x = self.proj_out_action(x)
+            x = x.reshape(B, T, D)
+        else:
+            x = self.proj_out_horizon(x)
+            x = einops.rearrange(x, 'b t h -> b h t')
+        return x
diff --git a/src/unifolm_wma/models/diffusion_head/conv1d_components.py b/src/unifolm_wma/models/diffusion_head/conv1d_components.py
new file mode 100644
index 0000000..713ac6c
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/conv1d_components.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Downsample1d(nn.Module):
+    """Halve the temporal length with a stride-2 convolution."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Upsample1d(nn.Module):
+    """Double the temporal length with a stride-2 transposed convolution."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Conv1dBlock(nn.Module):
+    '''
+    Conv1d --> GroupNorm --> Mish
+    '''
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels,
+                      out_channels,
+                      kernel_size,
+                      padding=kernel_size // 2),
+            # Rearrange('batch channels horizon -> batch channels 1 horizon'),
+            nn.GroupNorm(n_groups, out_channels),
+            # Rearrange('batch channels 1 horizon -> batch channels horizon'),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+def test():
+    cb = Conv1dBlock(256, 128, kernel_size=3)
+    x = torch.zeros((1, 256, 16))
+    o = cb(x)
+    # 'same' padding keeps the horizon length; only the channels change
+    assert o.shape == (1, 128, 16)
diff --git a/src/unifolm_wma/models/diffusion_head/ema_model.py b/src/unifolm_wma/models/diffusion_head/ema_model.py
new file mode 100644
index 0000000..c83a3b0
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/ema_model.py
@@ -0,0 +1,80 @@
+import copy
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights.
+    """
+
+    def __init__(self,
+                 model,
+                 update_after_step=0,
+                 inv_gamma=1.0,
+                 power=2 / 3,
+                 min_value=0.0,
+                 max_value=0.9999):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps);
+            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+            at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+            max_value (float): The maximum EMA decay rate. Default: 0.9999.
+        """
+
+        self.averaged_model = model
+        self.averaged_model.eval()
+        self.averaged_model.requires_grad_(False)
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
+        self.optimization_step = 0
+
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma)**-self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
+    @torch.no_grad()
+    def step(self, new_model):
+        self.decay = self.get_decay(self.optimization_step)
+
+        all_dataptrs = set()
+        for module, ema_module in zip(new_model.modules(),
+                                      self.averaged_model.modules()):
+            for param, ema_param in zip(module.parameters(recurse=False),
+                                        ema_module.parameters(recurse=False)):
+                # iterate over immediate parameters only.
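+                # The update applied below implements, per parameter:
+                #     ema_param <- decay * ema_param + (1 - decay) * param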
+ if isinstance(param, dict): + raise RuntimeError('Dict parameter not supported') + + if isinstance(module, _BatchNorm): + # skip batchnorms + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + elif not param.requires_grad: + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + else: + ema_param.mul_(self.decay) + ema_param.add_(param.data.to(dtype=ema_param.dtype), + alpha=1 - self.decay) + + # verify that iterating over module and then parameters is identical to parameters recursively. + # assert old_all_dataptrs == all_dataptrs + self.optimization_step += 1 diff --git a/src/unifolm_wma/models/diffusion_head/positional_embedding.py b/src/unifolm_wma/models/diffusion_head/positional_embedding.py new file mode 100644 index 0000000..1b1d646 --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/positional_embedding.py @@ -0,0 +1,19 @@ +import math +import torch +import torch.nn as nn + + +class SinusoidalPosEmb(nn.Module): + + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb diff --git a/src/unifolm_wma/models/diffusion_head/vision/__init__.py b/src/unifolm_wma/models/diffusion_head/vision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py b/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py new file mode 100644 index 0000000..d7b5408 --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py @@ -0,0 +1,322 @@ +import torch +import torch.nn as nn +import torchvision.transforms.functional as ttf +import unifolm_wma.models.diffusion_head.common.tensor_util as tu + + +class CropRandomizer(nn.Module): + """ + Randomly sample crops at input, and then average across crop features at output. + """ + + def __init__( + self, + input_shape, + crop_height, + crop_width, + num_crops=1, + pos_enc=False, + ): + """ + Args: + input_shape (tuple, list): shape of input (not including batch dimension) + crop_height (int): crop height + crop_width (int): crop width + num_crops (int): number of random crops to take + pos_enc (bool): if True, add 2 channels to the output to encode the spatial + location of the cropped pixels in the source image + """ + super().__init__() + + assert len(input_shape) == 3 # (C, H, W) + assert crop_height < input_shape[1] + assert crop_width < input_shape[2] + + self.input_shape = input_shape + self.crop_height = crop_height + self.crop_width = crop_width + self.num_crops = num_crops + self.pos_enc = pos_enc + + def output_shape_in(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_in operation, where raw inputs (usually observation modalities) + are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. 
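+
+        Example (illustrative): with ``input_shape=(3, 96, 96)``,
+        ``crop_height=84``, ``crop_width=84`` and ``pos_enc=False``, each
+        crop has shape ``[3, 84, 84]``.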
+ + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because + # the number of crops are reshaped into the batch dimension, increasing the batch + # size from B to B * N + out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] + return [out_c, self.crop_height, self.crop_width] + + def output_shape_out(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_out operation, where processed inputs (usually encoded observation + modalities) are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # since the forward_out operation splits [B * N, ...] -> [B, N, ...] + # and then pools to result in [B, ...], only the batch dimension changes, + # and so the other dimensions retain their shape. + return list(input_shape) + + def forward_in(self, inputs): + """ + Samples N random crops for each input in the batch, and then reshapes + inputs to [B * N, ...]. + """ + assert len( + inputs.shape) >= 3 # must have at least (C, H, W) dimensions + if self.training: + # generate random crops + out, _ = sample_random_image_crops( + images=inputs, + crop_height=self.crop_height, + crop_width=self.crop_width, + num_crops=self.num_crops, + pos_enc=self.pos_enc, + ) + # [B, N, ...] -> [B * N, ...] + return tu.join_dimensions(out, 0, 1) + else: + # take center crop during eval + out = ttf.center_crop(img=inputs, + output_size=(self.crop_height, + self.crop_width)) + if self.num_crops > 1: + B, C, H, W = out.shape + out = out.unsqueeze(1).expand(B, self.num_crops, C, H, + W).reshape(-1, C, H, W) + # [B * N, ...] + return out + + def forward_out(self, inputs): + """ + Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N + to result in shape [B, ...] to make sure the network output is consistent with + what would have happened if there were no randomization. + """ + if self.num_crops <= 1: + return inputs + else: + batch_size = (inputs.shape[0] // self.num_crops) + out = tu.reshape_dimensions(inputs, + begin_axis=0, + end_axis=0, + target_dims=(batch_size, + self.num_crops)) + return out.mean(dim=1) + + def forward(self, inputs): + return self.forward_in(inputs) + + def __repr__(self): + """Pretty print network.""" + header = '{}'.format(str(self.__class__.__name__)) + msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( + self.input_shape, self.crop_height, self.crop_width, + self.num_crops) + return msg + + +def crop_image_from_indices(images, crop_indices, crop_height, crop_width): + """ + Crops images at the locations specified by @crop_indices. Crops will be + taken across all channels. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where + N is the number of crops to take per image and each entry corresponds + to the pixel height and width of where to take the crop. Note that + the indices can also be of shape [..., 2] if only 1 crop should + be taken per image. Leading dimensions must be consistent with + @images argument. Each index specifies the top left of the crop. 
+            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
+            H and W are the height and width of @images and CH and CW are
+            @crop_height and @crop_width.
+
+        crop_height (int): height of crop to take
+
+        crop_width (int): width of crop to take
+
+    Returns:
+        crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
+    """
+
+    # make sure length of input shapes is consistent
+    assert crop_indices.shape[-1] == 2
+    ndim_im_shape = len(images.shape)
+    ndim_indices_shape = len(crop_indices.shape)
+    assert (ndim_im_shape == ndim_indices_shape +
+            1) or (ndim_im_shape == ndim_indices_shape + 2)
+
+    # maybe pad so that @crop_indices is shape [..., N, 2]
+    is_padded = False
+    if ndim_im_shape == ndim_indices_shape + 2:
+        crop_indices = crop_indices.unsqueeze(-2)
+        is_padded = True
+
+    # make sure leading dimensions between images and indices are consistent
+    assert images.shape[:-3] == crop_indices.shape[:-2]
+
+    device = images.device
+    image_c, image_h, image_w = images.shape[-3:]
+    num_crops = crop_indices.shape[-2]
+
+    # make sure @crop_indices are in valid range
+    assert (crop_indices[..., 0] >= 0).all().item()
+    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
+    assert (crop_indices[..., 1] >= 0).all().item()
+    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
+
+    # Convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
+
+    # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
+    crop_ind_grid_h = torch.arange(crop_height).to(device)
+    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
+                                             size=crop_width,
+                                             dim=-1)
+    # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
+    crop_ind_grid_w = torch.arange(crop_width).to(device)
+    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
+                                             size=crop_height,
+                                             dim=0)
+    # combine into shape [CH, CW, 2]
+    crop_in_grid = torch.cat(
+        (crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
+
+    # Add the above grid to the offset index of each sampled crop to get 2d indices for each crop.
+    # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
+    # shape array that tells us which pixels from the corresponding source image to grab.
+    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
+        crop_height, crop_width, 2
+    ]
+    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
+        -2) + crop_in_grid.reshape(grid_reshape)
+
+    # For using @torch.gather, convert to flat indices from 2D indices, and also
+    # repeat across the channel dimension. 
To get flat index of each pixel to grab for + # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind + all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[ + ..., 1] # shape [..., N, CH, CW] + all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, + dim=-3) # shape [..., N, C, CH, CW] + all_crop_inds = tu.flatten(all_crop_inds, + begin_axis=-2) # shape [..., N, C, CH * CW] + + # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds + images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4) + images_to_crop = tu.flatten(images_to_crop, begin_axis=-2) + crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) + # [..., N, C, CH * CW] -> [..., N, C, CH, CW] + reshape_axis = len(crops.shape) - 1 + crops = tu.reshape_dimensions(crops, + begin_axis=reshape_axis, + end_axis=reshape_axis, + target_dims=(crop_height, crop_width)) + + if is_padded: + # undo padding -> [..., C, CH, CW] + crops = crops.squeeze(-4) + return crops + + +def sample_random_image_crops(images, + crop_height, + crop_width, + num_crops, + pos_enc=False): + """ + For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from + @images. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_height (int): height of crop to take + + crop_width (int): width of crop to take + + num_crops (n): number of crops to sample + + pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial + encoding of the original source pixel locations. This means that the + output crops will contain information about where in the source image + it was sampled from. + + Returns: + crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) + if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) + + crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) + """ + device = images.device + + # maybe add 2 channels of spatial encoding to the source image + source_im = images + if pos_enc: + # spatial encoding [y, x] in [0, 1] + h, w = source_im.shape[-2:] + pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w)) + pos_y = pos_y.float().to(device) / float(h) + pos_x = pos_x.float().to(device) / float(w) + position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] + + # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] + leading_shape = source_im.shape[:-3] + position_enc = position_enc[(None, ) * len(leading_shape)] + position_enc = position_enc.expand(*leading_shape, -1, -1, -1) + + # concat across channel dimension with input + source_im = torch.cat((source_im, position_enc), dim=-3) + + # make sure sample boundaries ensure crops are fully within the images + image_c, image_h, image_w = source_im.shape[-3:] + max_sample_h = image_h - crop_height + max_sample_w = image_w - crop_width + + # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. + # Each gets @num_crops samples - typically this will just be the batch dimension (B), so + # we will sample [B, N] indices, but this supports having more than one leading dimension, + # or possibly no leading dimension. 
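+    # For example, images of shape [B, C, H, W] with num_crops=N yield
+    # crop_inds_h and crop_inds_w of shape [B, N] and, after the concat
+    # below, crop_inds of shape [B, N, 2].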
+ # + # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints + crop_inds_h = ( + max_sample_h * + torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() + crop_inds_w = ( + max_sample_w * + torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() + crop_inds = torch.cat( + (crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), + dim=-1) # shape [..., N, 2] + + crops = crop_image_from_indices( + images=source_im, + crop_indices=crop_inds, + crop_height=crop_height, + crop_width=crop_width, + ) + + return crops, crop_inds diff --git a/src/unifolm_wma/models/diffusion_head/vision/model_getter.py b/src/unifolm_wma/models/diffusion_head/vision/model_getter.py new file mode 100644 index 0000000..9c33ff0 --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/vision/model_getter.py @@ -0,0 +1,30 @@ +import torch +import torchvision + + +def get_resnet(name, weights=None, **kwargs): + """ + name: resnet18, resnet34, resnet50 + weights: "IMAGENET1K_V1", "r3m" + """ + # load r3m weights + if (weights == "r3m") or (weights == "R3M"): + return get_r3m(name=name, **kwargs) + + func = getattr(torchvision.models, name) + resnet = func(weights=weights, **kwargs) + resnet.fc = torch.nn.Identity() + return resnet + + +def get_r3m(name, **kwargs): + """ + name: resnet18, resnet34, resnet50 + """ + import r3m + r3m.device = 'cpu' + model = r3m.load_r3m(name) + r3m_model = model.module + resnet_model = r3m_model.convnet + resnet_model = resnet_model.to('cpu') + return resnet_model diff --git a/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py b/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py new file mode 100644 index 0000000..a46b0ed --- /dev/null +++ b/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py @@ -0,0 +1,247 @@ +import copy +import torch +import torch.nn as nn +import torchvision +import json +import os + +from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer +from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax +from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin +from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules +from unifolm_wma.utils.utils import instantiate_from_config +from einops import rearrange, repeat +from typing import Dict, Tuple, Union +from pathlib import Path + + +class MultiImageObsEncoder(ModuleAttrMixin): + + def __init__( + self, + rgb_model_config: Dict, + shape_meta_path: str | None = None, + resize_shape: Union[Tuple[int, int], Dict[str, tuple], + None] = None, + crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, + random_crop: bool = True, + # replace BatchNorm with GroupNorm + use_group_norm: bool = False, + # use single rgb model for all rgb inputs + share_rgb_model: bool = False, + # renormalize rgb input with imagenet normalization + # assuming input in [0,1] + imagenet_norm: bool = False, + use_spatial_softmax=False, + spatial_softmax_kp=32, + use_dinoSiglip=False): + """ + Assumes rgb input: B,C,H,W + Assumes low_dim input: B,D + """ + super().__init__() + + if not shape_meta_path: + shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json") + + with open(shape_meta_path, 'r') as file: + shape_meta = json.load(file) + + rgb_model = instantiate_from_config(rgb_model_config) + + rgb_keys = list() + low_dim_keys = list() + key_model_map = nn.ModuleDict() + key_transform_map = 
nn.ModuleDict()
+        key_shape_map = dict()
+
+        # handle sharing vision backbone
+        if share_rgb_model:
+            assert isinstance(rgb_model, nn.Module)
+            key_model_map['rgb'] = rgb_model
+
+        obs_shape_meta = shape_meta['obs']
+        for key, attr in obs_shape_meta.items():
+            shape = tuple(attr['shape'])
+            type = attr.get('type', 'low_dim')
+            key_shape_map[key] = shape
+            if type == 'rgb':
+                rgb_keys.append(key)
+                if not use_dinoSiglip:
+                    # configure model for this key
+                    this_model = None
+                    if not share_rgb_model:
+                        if isinstance(rgb_model, dict):
+                            # have provided model for each key
+                            this_model = rgb_model[key]
+                        else:
+                            assert isinstance(rgb_model, nn.Module)
+                            # have a copy of the rgb model
+                            this_model = copy.deepcopy(rgb_model)
+
+                    if this_model is not None:
+                        if use_group_norm:
+                            this_model = replace_submodules(
+                                root_module=this_model,
+                                predicate=lambda x: isinstance(
+                                    x, nn.BatchNorm2d),
+                                func=lambda x: nn.GroupNorm(
+                                    num_groups=x.num_features // 16,
+                                    num_channels=x.num_features))
+                        key_model_map[key] = this_model
+
+                    # configure resize
+                    input_shape = shape
+                    this_resizer = nn.Identity()
+                    if resize_shape is not None:
+                        if isinstance(resize_shape, dict):
+                            h, w = resize_shape[key]
+                        else:
+                            h, w = resize_shape
+                        this_resizer = torchvision.transforms.Resize(size=(h,
+                                                                           w))
+                        input_shape = (shape[0], h, w)
+
+                    # configure randomizer
+                    this_randomizer = nn.Identity()
+                    if crop_shape is not None:
+                        if isinstance(crop_shape, dict):
+                            h, w = crop_shape[key]
+                        else:
+                            h, w = crop_shape
+                        if random_crop:
+                            this_randomizer = CropRandomizer(
+                                input_shape=input_shape,
+                                crop_height=h,
+                                crop_width=w,
+                                num_crops=1,
+                                pos_enc=False)
+                        else:
+                            this_randomizer = torchvision.transforms.CenterCrop(
+                                size=(h, w))
+                    # configure normalizer
+                    this_normalizer = nn.Identity()
+                    if imagenet_norm:
+                        this_normalizer = torchvision.transforms.Normalize(
+                            mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225])
+
+                    this_transform = nn.Sequential(this_resizer,
+                                                   this_randomizer,
+                                                   this_normalizer)
+                    key_transform_map[key] = this_transform
+                else:
+                    key_model_map[key] = rgb_model
+            elif type == 'low_dim':
+                low_dim_keys.append(key)
+            else:
+                raise RuntimeError(f"Unsupported obs type: {type}")
+
+        rgb_keys = sorted(rgb_keys)
+        low_dim_keys = sorted(low_dim_keys)
+
+        self.shape_meta = shape_meta
+        self.key_model_map = key_model_map
+        self.key_transform_map = key_transform_map
+        self.share_rgb_model = share_rgb_model
+        self.rgb_keys = rgb_keys
+        self.low_dim_keys = low_dim_keys
+        self.key_shape_map = key_shape_map
+        self.use_dinoSiglip = use_dinoSiglip
+
+        # NOTE: add spatial softmax pooling over the ResNet feature map
+        self.use_spatial_softmax = use_spatial_softmax
+        if use_spatial_softmax and not use_dinoSiglip:
+            model = nn.Sequential(
+                key_model_map['image'].conv1,
+                key_model_map['image'].bn1,
+                key_model_map['image'].relu,
+                key_model_map['image'].maxpool,
+                key_model_map['image'].layer1,
+                key_model_map['image'].layer2,
+                key_model_map['image'].layer3,
+                key_model_map['image'].layer4,
+            )
+            key_model_map['image'] = model
+            input_shape = self.output_shape(resnet_output_shape=True)
+            self.spatial_softmax = SpatialSoftmax(input_shape,
+                                                  num_kp=spatial_softmax_kp)
+
+    def forward(self, obs_dict, resnet_output_shape=False):
+        batch_size = None
+        features = list()
+        # process rgb input
+        if self.share_rgb_model:
+            # pass all rgb obs to rgb model
+            imgs = list()
+            for key in self.rgb_keys:
+                img = obs_dict[key]
+                if batch_size is None:
+                    batch_size = img.shape[0]
+                else:
+                    assert batch_size == img.shape[0]
+                assert img.shape[1:] == self.key_shape_map[key]
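+                # Apply this key's resize/crop/normalize pipeline before all
+                # camera streams are stacked into a single (N*B, C, H, W)
+                # batch for one pass through the shared backbone.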
+ img = self.key_transform_map[key](img) + imgs.append(img) + # (N*B,C,H,W) + imgs = torch.cat(imgs, dim=0) + # (N*B,D) + feature = self.key_model_map['rgb'](imgs) + # (N,B,D) + feature = feature.reshape(-1, batch_size, *feature.shape[1:]) + # (B,N,D) + feature = torch.moveaxis(feature, 0, 1) + # (B,N*D) + feature = feature.reshape(batch_size, -1) + features.append(feature) + else: + # run each rgb obs to independent models + for key in self.rgb_keys: + img = obs_dict[key] + if batch_size is None: + batch_size = img.shape[0] + else: + assert batch_size == img.shape[0] + assert img.shape[1:] == self.key_shape_map[key] + if not self.use_dinoSiglip: + img = self.key_transform_map[key](img) + feature = self.key_model_map[key](img) + else: + feature = self.key_model_map[key](img)[:, :1, :] + + if resnet_output_shape: + return feature + if not self.use_dinoSiglip and self.use_spatial_softmax: + feature = self.spatial_softmax(feature) + feature = feature.reshape(batch_size, -1) + features.append(feature) + + # process lowdim input + for key in self.low_dim_keys: + data = obs_dict[key] + if batch_size is None: + batch_size = data.shape[0] + else: + assert batch_size == data.shape[0] + assert data.shape[1:] == self.key_shape_map[key] + features.append(data) + + # concatenate all features + result = torch.cat(features, dim=-1) + return result + + @torch.no_grad() + def output_shape(self, resnet_output_shape=False): + example_obs_dict = dict() + obs_shape_meta = self.shape_meta['obs'] + batch_size = 1 + for key, attr in obs_shape_meta.items(): + shape = tuple(attr['shape']) + this_obs = torch.zeros((batch_size, ) + shape, + dtype=self.dtype, + device=self.device) + example_obs_dict[key] = this_obs + example_output = self.forward(example_obs_dict, + resnet_output_shape=resnet_output_shape) + output_shape = example_output.shape[1:] + return output_shape diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py new file mode 100644 index 0000000..77602a1 --- /dev/null +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -0,0 +1,473 @@ +import numpy as np +import torch +import copy + +from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg +from unifolm_wma.utils.common import noise_like +from unifolm_wma.utils.common import extract_into_tensor +from tqdm import tqdm + + +class DDIMSampler(object): + + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.counter = 0 + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, + ddim_num_steps, + ddim_discretize="uniform", + ddim_eta=0., + verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[ + 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model + .device) + + if self.model.use_dynamic_rescale: + self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] + self.ddim_scale_arr_prev = torch.cat( + [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) + + 
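+        # Cache the full DDPM schedule as float32 tensors on the model device;
+        # the DDIM-specific alphas/sigmas below are then derived from it for
+        # the selected subset of timesteps.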
self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', + to_torch(self.model.alphas_cumprod_prev)) + + # Calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # DDIM sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', + np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + schedule_verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + precision=None, + fs=None, + timestep_spacing='uniform', #uniform_trailing for starting from last timestep + guidance_rescale=0.0, + **kwargs): + + # Check condition bs + if conditioning is not None: + if isinstance(conditioning, dict): + try: + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + except: + cbs = conditioning[list( + conditioning.keys())[0]][0].shape[0] + + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, + ddim_discretize=timestep_spacing, + ddim_eta=eta, + verbose=schedule_verbose) + + # Make shape + if len(shape) == 3: + C, H, W = shape + size = (batch_size, C, H, W) + elif len(shape) == 4: + C, T, H, W = shape + size = (batch_size, C, T, H, W) + + samples, actions, states, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + verbose=verbose, + precision=precision, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) + return samples, actions, states, intermediates + + 
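+    # Illustrative usage sketch (not taken from this repository; `model`,
+    # `cond` and the latent shape (C, T, H, W) are assumed placeholders):
+    #
+    #     sampler = DDIMSampler(model)
+    #     samples, actions, states, inter = sampler.sample(
+    #         S=50, batch_size=1, shape=(4, 16, 32, 32),
+    #         conditioning=cond, eta=0.0, verbose=False)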
@torch.no_grad() + def ddim_sampling(self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + verbose=True, + precision=None, + fs=None, + guidance_rescale=0.0, + **kwargs): + device = self.model.betas.device + dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action + dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state + + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + action = torch.randn((b, 16, self.model.agent_action_dim), + device=device) + state = torch.randn((b, 16, self.model.agent_state_dim), + device=device) + else: + img = x_T + action = torch.randn((b, 16, self.model.agent_action_dim), + device=device) + state = torch.randn((b, 16, self.model.agent_state_dim), + device=device) + + if precision is not None: + if precision == 16: + img = img.to(dtype=torch.float16) + action = action.to(dtype=torch.float16) + state = state.to(dtype=torch.float16) + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int( + min(timesteps / self.ddim_timesteps.shape[0], 1) * + self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = { + 'x_inter': [img], + 'pred_x0': [img], + 'x_inter_action': [action], + 'pred_x0_action': [action], + 'x_inter_state': [state], + 'pred_x0_state': [state], + } + time_range = reversed(range( + 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ + 0] + if verbose: + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + else: + iterator = time_range + + clean_cond = kwargs.pop("clean_cond", False) + + dp_ddim_scheduler_action.set_timesteps(len(timesteps)) + dp_ddim_scheduler_state.set_timesteps(len(timesteps)) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b, ), step, device=device, dtype=torch.long) + + # Use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim( + img, + action, + state, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + x0=x0, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) + + img, pred_x0, model_output_action, model_output_state = outs + + action = dp_ddim_scheduler_action.step( + model_output_action, + step, + action, + generator=None, + ).prev_sample + state = dp_ddim_scheduler_state.step( + model_output_state, + step, + state, + generator=None, + ).prev_sample + + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + intermediates['x_inter_action'].append(action) + intermediates['x_inter_state'].append(state) + + return img, action, state, intermediates + + @torch.no_grad() + def p_sample_ddim(self, + x, + x_action, + x_state, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + uc_type=None, + conditional_guidance_scale_temporal=None, + mask=None, + x0=None, + guidance_rescale=0.0, + **kwargs): + b, *_, device = *x.shape, x.device + if x.dim() == 5: + is_video = True + else: + is_video = False + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output, model_output_action, model_output_state = self.model.apply_model( + x, x_action, x_state, t, c, **kwargs) # unet denoiser + else: + # do_classifier_free_guidance + if isinstance(c, torch.Tensor) or isinstance(c, dict): + e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model( + x, x_action, x_state, t, c, **kwargs) + e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model( + x, x_action, x_state, t, unconditional_conditioning, + **kwargs) + else: + raise NotImplementedError + model_output = e_t_uncond + unconditional_guidance_scale * ( + e_t_cond - e_t_uncond) + model_output_action = e_t_uncond_action + unconditional_guidance_scale * ( + e_t_cond_action - e_t_uncond_action) + model_output_state = e_t_uncond_state + unconditional_guidance_scale * ( + e_t_cond_state - e_t_uncond_state) + + if guidance_rescale > 0.0: + model_output = rescale_noise_cfg( + model_output, e_t_cond, guidance_rescale=guidance_rescale) + model_output_action = rescale_noise_cfg( + model_output_action, + e_t_cond_action, + guidance_rescale=guidance_rescale) + model_output_state = rescale_noise_cfg( + model_output_state, + e_t_cond_state, + guidance_rescale=guidance_rescale) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, + **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = 
self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + if is_video: + size = (b, 1, 1, 1, 1) + else: + size = (b, 1, 1, 1) + + a_t = torch.full(size, alphas[index], device=device) + a_prev = torch.full(size, alphas_prev[index], device=device) + sigma_t = torch.full(size, sigmas[index], device=device) + sqrt_one_minus_at = torch.full(size, + sqrt_one_minus_alphas[index], + device=device) + + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if self.model.use_dynamic_rescale: + scale_t = torch.full(size, + self.ddim_scale_arr[index], + device=device) + prev_scale_t = torch.full(size, + self.ddim_scale_arr_prev[index], + device=device) + rescale = (prev_scale_t / scale_t) + pred_x0 *= rescale + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device, + repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0, model_output_action, model_output_state + + @torch.no_grad() + def decode(self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps + ) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0], ), + step, + device=x_latent.device, + dtype=torch.long) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * + noise) diff --git a/src/unifolm_wma/modules/__init__.py b/src/unifolm_wma/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py new file mode 100644 index 0000000..7b21317 --- /dev/null +++ b/src/unifolm_wma/modules/attention.py @@ -0,0 +1,806 @@ +import torch +import torch.nn.functional as F + +from torch import nn, einsum +from einops import rearrange, repeat +from functools import partial + +try: + import xformers + import xformers.ops + 
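+    # xformers provides fused, memory-efficient attention kernels; if the
+    # import fails, the flag below keeps attention on the plain einsum path.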
XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +from unifolm_wma.utils.common import ( + checkpoint, + exists, + default, +) +from unifolm_wma.utils.basics import zero_module + + +class RelativePosition(nn.Module): + """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ + + def __init__(self, num_units, max_relative_position): + super().__init__() + self.num_units = num_units + self.max_relative_position = max_relative_position + self.embeddings_table = nn.Parameter( + torch.Tensor(max_relative_position * 2 + 1, num_units)) + nn.init.xavier_uniform_(self.embeddings_table) + + def forward(self, length_q, length_k): + device = self.embeddings_table.device + range_vec_q = torch.arange(length_q, device=device) + range_vec_k = torch.arange(length_k, device=device) + distance_mat = range_vec_k[None, :] - range_vec_q[:, None] + distance_mat_clipped = torch.clamp(distance_mat, + -self.max_relative_position, + self.max_relative_position) + final_mat = distance_mat_clipped + self.max_relative_position + final_mat = final_mat.long() + embeddings = self.embeddings_table[final_mat] + return embeddings + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0., + relative_position=False, + temporal_length=None, + video_length=None, + agent_state_context_len=2, + agent_action_context_len=16, + image_cross_attention=False, + image_cross_attention_scale=1.0, + agent_state_cross_attention_scale=1.0, + agent_action_cross_attention_scale=1.0, + cross_attention_scale_learnable=False, + text_context_len=77): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout)) + + self.relative_position = relative_position + if self.relative_position: + assert (temporal_length is not None) + self.relative_position_k = RelativePosition( + num_units=dim_head, max_relative_position=temporal_length) + self.relative_position_v = RelativePosition( + num_units=dim_head, max_relative_position=temporal_length) + else: + ## only used for spatial attention, while NOT for temporal attention + if XFORMERS_IS_AVAILBLE and temporal_length is None: + self.forward = self.efficient_forward + + self.video_length = video_length + self.image_cross_attention = image_cross_attention + self.image_cross_attention_scale = image_cross_attention_scale + self.agent_state_cross_attention_scale = agent_state_cross_attention_scale + self.agent_action_cross_attention_scale = agent_action_cross_attention_scale + self.text_context_len = text_context_len + self.agent_state_context_len = agent_state_context_len + self.agent_action_context_len = agent_action_context_len + self.cross_attention_scale_learnable = cross_attention_scale_learnable + if self.image_cross_attention: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False) + self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_aa = nn.Linear(context_dim, 
inner_dim, bias=False) + if cross_attention_scale_learnable: + self.register_parameter('alpha_ctx', + nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_cas', + nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_caa', + nn.Parameter(torch.tensor(0.))) + + def forward(self, x, context=None, mask=None): + spatial_self_attn = (context is None) + k_ip, v_ip, out_ip = None, None, None + k_as, v_as, out_as = None, None, None + k_aa, v_aa, out_aa = None, None, None + + h = self.heads + q = self.to_q(x) + context = default(context, x) + + if self.image_cross_attention and not spatial_self_attn: + assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..." + context_agent_state = context[:, :self.agent_state_context_len, :] + context_agent_action = context[:, + self.agent_state_context_len:self. + agent_state_context_len + + self.agent_action_context_len, :] + context_ins = context[:, self.agent_state_context_len + + self.agent_action_context_len:self. + agent_state_context_len + + self.agent_action_context_len + + self.text_context_len, :] + context_image = context[:, self.agent_state_context_len + + self.agent_action_context_len + + self.text_context_len:, :] + + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) + else: + if not spatial_self_attn: + context = context[:, :self.text_context_len, :] + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + if self.relative_position: + len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] + k2 = self.relative_position_k(len_q, len_k) + sim2 = einsum('b t d, t s d -> b t s', q, + k2) * self.scale # TODO check + sim += sim2 + del k + + if exists(mask): + ## feasible for causal attention mask only + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b i j -> (b h) i j', h=h) + sim.masked_fill_(~(mask > 0.5), max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + if self.relative_position: + v2 = self.relative_position_v(len_q, len_v) + out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check + out += out2 + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + if k_ip is not None and k_as is not None and k_aa is not None: + ## for image cross-attention + k_ip, v_ip = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_ip, v_ip)) + sim_ip = torch.einsum('b i d, b j d -> b i j', q, + k_ip) * self.scale + del k_ip + sim_ip = sim_ip.softmax(dim=-1) + out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) + out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) + + ## for agent state cross-attention + k_as, v_as = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_as, v_as)) + sim_as = torch.einsum('b i d, b j d -> b i j', q, + k_as) * self.scale + del k_as + sim_as = sim_as.softmax(dim=-1) + out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as) + out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h) + + ## for agent action cross-attention + k_aa, v_aa = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_aa, v_aa)) + sim_aa = 
torch.einsum('b i d, b j d -> b i j', q, + k_aa) * self.scale + del k_aa + sim_aa = sim_aa.softmax(dim=-1) + out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa) + out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h) + + if out_ip is not None and out_as is not None and out_aa is not None: + if self.cross_attention_scale_learnable: + out = out + \ + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \ + self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \ + self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1) + else: + out = out + \ + self.image_cross_attention_scale * out_ip + \ + self.agent_state_cross_attention_scale * out_as + \ + self.agent_action_cross_attention_scale * out_aa + + return self.to_out(out) + + def efficient_forward(self, x, context=None, mask=None): + spatial_self_attn = (context is None) + k, v, out = None, None, None + k_ip, v_ip, out_ip = None, None, None + k_as, v_as, out_as = None, None, None + k_aa, v_aa, out_aa = None, None, None + + q = self.to_q(x) + context = default(context, x) + + if self.image_cross_attention and not spatial_self_attn: + if context.shape[1] == self.text_context_len + self.video_length: + context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length: + context_agent_state = context[:, :self.agent_state_context_len, :] + context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] + context_image = context[:, self.agent_state_context_len+self.text_context_len:, :] + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + else: + context_agent_state = context[:, :self.agent_state_context_len, :] + context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :] + context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :] + context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :] + + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) + + attn_mask_aa = self._get_attn_mask_aa(x.shape[0], + q.shape[1], + k_aa.shape[1], + block_size=16).to(k_aa.device) + else: + if not spatial_self_attn: + assert 1 > 2, ">>> ERROR: you should never go into here ..." 
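+                # Without image cross-attention the context is assumed to
+                # start with the text tokens, so only those are attended to.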
+ context = context[:, :self.text_context_len, :] + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous() + if k is not None: + k, v = map( + lambda t: t.unsqueeze(3).reshape(b, t.shape[ + 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( + b * self.heads, t.shape[1], self.dim_head).contiguous(), + (k, v), + ) + out = xformers.ops.memory_efficient_attention(q, + k, + v, + attn_bias=None, + op=None) + out = (out.unsqueeze(0).reshape( + b, self.heads, out.shape[1], + self.dim_head).permute(0, 2, 1, + 3).reshape(b, out.shape[1], + self.heads * self.dim_head)) + + if k_ip is not None: + # For image cross-attention + k_ip, v_ip = map( + lambda t: t.unsqueeze(3).reshape(b, t.shape[ + 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( + b * self.heads, t.shape[1], self.dim_head).contiguous( + ), + (k_ip, v_ip), + ) + out_ip = xformers.ops.memory_efficient_attention(q, + k_ip, + v_ip, + attn_bias=None, + op=None) + out_ip = (out_ip.unsqueeze(0).reshape( + b, self.heads, out_ip.shape[1], + self.dim_head).permute(0, 2, 1, + 3).reshape(b, out_ip.shape[1], + self.heads * self.dim_head)) + + if k_as is not None: + # For agent state cross-attention + k_as, v_as = map( + lambda t: t.unsqueeze(3).reshape(b, t.shape[ + 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( + b * self.heads, t.shape[1], self.dim_head).contiguous( + ), + (k_as, v_as), + ) + out_as = xformers.ops.memory_efficient_attention(q, + k_as, + v_as, + attn_bias=None, + op=None) + out_as = (out_as.unsqueeze(0).reshape( + b, self.heads, out_as.shape[1], + self.dim_head).permute(0, 2, 1, + 3).reshape(b, out_as.shape[1], + self.heads * self.dim_head)) + if k_aa is not None: + # For agent action cross-attention + k_aa, v_aa = map( + lambda t: t.unsqueeze(3).reshape(b, t.shape[ + 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( + b * self.heads, t.shape[1], self.dim_head).contiguous( + ), + (k_aa, v_aa), + ) + + attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape( + b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2]) + attn_mask_aa = attn_mask_aa.to(q.dtype) + + out_aa = xformers.ops.memory_efficient_attention( + q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None) + + out_aa = (out_aa.unsqueeze(0).reshape( + b, self.heads, out_aa.shape[1], + self.dim_head).permute(0, 2, 1, + 3).reshape(b, out_aa.shape[1], + self.heads * self.dim_head)) + if exists(mask): + raise NotImplementedError + + out = 0.0 if out is None else out + out_ip = 0.0 if out_ip is None else out_ip + out_as = 0.0 if out_as is None else out_as + out_aa = 0.0 if out_aa is None else out_aa + + if self.cross_attention_scale_learnable: + out = out + \ + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \ + self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \ + self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1) + + else: + out = out + \ + self.image_cross_attention_scale * out_ip + \ + self.agent_state_cross_attention_scale * out_as + \ + self.agent_action_cross_attention_scale * out_aa + + return self.to_out(out) + + def _get_attn_mask_aa(self, b, l1, l2, block_size=16): + num_token = l2 // block_size + start_positions = ((torch.arange(b) % block_size) + 1) * num_token + col_indices = torch.arange(l2) + mask_2d = col_indices.unsqueeze(0) >= 
start_positions.unsqueeze(1) + mask = mask_2d.unsqueeze(1).expand(b, l1, l2) + attn_mask = torch.zeros_like(mask, dtype=torch.float) + attn_mask[mask] = float('-inf') + return attn_mask + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attention_cls=None, + video_length=None, + agent_state_context_len=2, + agent_action_context_len=16, + image_cross_attention=False, + image_cross_attention_scale=1.0, + cross_attention_scale_learnable=False, + text_context_len=77): + super().__init__() + attn_cls = CrossAttention if attention_cls is None else attention_cls + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + video_length=video_length, + agent_state_context_len=agent_state_context_len, + agent_action_context_len=agent_action_context_len, + image_cross_attention=image_cross_attention, + image_cross_attention_scale=image_cross_attention_scale, + cross_attention_scale_learnable=cross_attention_scale_learnable, + text_context_len=text_context_len) + self.image_cross_attention = image_cross_attention + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None, **kwargs): + # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments + input_tuple = ( + x, + ) # should not be (x), otherwise *input_tuple will decouple x into multiple arguments + if context is not None: + input_tuple = (x, context) + if mask is not None: + forward_mask = partial(self._forward, mask=mask) + return checkpoint(forward_mask, (x, ), self.parameters(), + self.checkpoint) + return checkpoint(self._forward, input_tuple, self.parameters(), + self.checkpoint) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x), + context=context if self.disable_self_attn else None, + mask=mask) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + use_checkpoint=True, + disable_self_attn=False, + use_linear=False, + video_length=None, + agent_state_context_len=2, + agent_action_context_len=16, + image_cross_attention=False, + cross_attention_scale_learnable=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + attention_cls = None + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + attention_cls=attention_cls, + video_length=video_length, + agent_state_context_len=agent_state_context_len, + agent_action_context_len=agent_action_context_len, + image_cross_attention=image_cross_attention, + cross_attention_scale_learnable=cross_attention_scale_learnable, + ) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None, **kwargs): + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context, **kwargs) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. 
+        Finally, reshape to image
+    """
+
+    def __init__(self,
+                 in_channels,
+                 n_heads,
+                 d_head,
+                 depth=1,
+                 dropout=0.,
+                 context_dim=None,
+                 use_checkpoint=True,
+                 use_linear=False,
+                 only_self_att=True,
+                 causal_attention=False,
+                 causal_block_size=1,
+                 relative_position=False,
+                 temporal_length=None):
+        super().__init__()
+        self.only_self_att = only_self_att
+        self.relative_position = relative_position
+        self.causal_attention = causal_attention
+        self.causal_block_size = causal_block_size
+
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = torch.nn.GroupNorm(num_groups=32,
+                                       num_channels=in_channels,
+                                       eps=1e-6,
+                                       affine=True)
+        if not use_linear:
+            self.proj_in = nn.Conv1d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        if relative_position:
+            assert (temporal_length is not None)
+            attention_cls = partial(CrossAttention,
+                                    relative_position=True,
+                                    temporal_length=temporal_length)
+        else:
+            attention_cls = partial(CrossAttention,
+                                    temporal_length=temporal_length)
+        if self.causal_attention:
+            assert (temporal_length is not None)
+            self.mask = torch.tril(
+                torch.ones([1, temporal_length, temporal_length]))
+
+        if self.only_self_att:
+            context_dim = None
+        self.transformer_blocks = nn.ModuleList([
+            BasicTransformerBlock(inner_dim,
+                                  n_heads,
+                                  d_head,
+                                  dropout=dropout,
+                                  context_dim=context_dim,
+                                  attention_cls=attention_cls,
+                                  checkpoint=use_checkpoint)
+            for d in range(depth)
+        ])
+        if not use_linear:
+            self.proj_out = zero_module(
+                nn.Conv1d(inner_dim,
+                          in_channels,
+                          kernel_size=1,
+                          stride=1,
+                          padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        b, c, t, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+
+        temp_mask = None
+        if self.causal_attention:
+            # Slice the causal mask to the current temporal length
+            temp_mask = self.mask[:, :t, :t].to(x.device)
+
+        if temp_mask is not None:
+            mask = temp_mask.to(x.device)
+            mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
+        else:
+            mask = None
+
+        if self.only_self_att:
+            # NOTE: if no context is given, cross-attention defaults to self-attention
+            for i, block in enumerate(self.transformer_blocks):
+                x = block(x, mask=mask)
+            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+        else:
+            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+            context = rearrange(context, '(b t) l con -> b t l con',
+                                t=t).contiguous()
+            for i, block in enumerate(self.transformer_blocks):
+                # Process each batch element one by one (some attention kernels
+                # cannot handle a batch dimension greater than 65,535)
+                for j in range(b):
+                    context_j = repeat(context[j],
+                                       't l con -> (t r) l con',
+                                       r=(h * w) // t,
+                                       t=t).contiguous()
+                    # Note: the causal mask is not applied in the cross-attention case
+                    x[j] = block(x[j], context=context_j)
+
+        if self.use_linear:
+            x = self.proj_out(x)
+            x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
+            x = self.proj_out(x)
+            x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
+                          w=w).contiguous()
+
+        return x + x_in
+
+
+class 
GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +class LinearAttention(nn.Module): + + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, + 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads=self.heads, + qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, + 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, + h=h, + w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # Compute attention + b, c, h, w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # Attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x + h_ diff --git a/src/unifolm_wma/modules/encoders/condition.py b/src/unifolm_wma/modules/encoders/condition.py new file mode 100644 index 0000000..44f0fdc --- /dev/null +++ b/src/unifolm_wma/modules/encoders/condition.py @@ -0,0 +1,630 @@ +import torch +import torch.nn as nn +import kornia +import open_clip +import math + +from torch.utils.checkpoint import checkpoint +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +from unifolm_wma.utils.common import autocast +from unifolm_wma.utils.utils import count_params +from unifolm_wma.modules.encoders.resampler import reshape_tensor + + +class AbstractEncoder(nn.Module): + + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + + def __init__(self, embed_dim, n_classes=1000, key='class', 
ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - + 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs, ), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__(self, + version="google/t5-v1_1-xxl", + device="cuda", + max_length=77, + freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = ["last", "pooled", "hidden"] + + def __init__(self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, + output_hidden_states=self.layer == "hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + 
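+# A minimal usage sketch (illustrative only): the frozen text encoders map a
+# batch of prompts to token-level embeddings that downstream modules use as
+# cross-attention context. With the default openai/clip-vit-large-patch14
+# weights, the output shape is (batch, 77, 768):
+#
+#   encoder = FrozenCLIPEmbedder(device="cpu")
+#   z = encoder.encode(["pick up the red cube"])   # -> torch.Size([1, 77, 768])
+
+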
+class ClipImageEmbedder(nn.Module): + + def __init__(self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0.): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer('std', + torch.Tensor([0.26862954, 0.26130258, + 0.27577711]), + persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli( + (1. - self.ucg_rate) * + torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + # "pooled", + "last", + "penultimate" + ] + + def __init__(self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize( + text) ## all clip models use 77 as context length + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting( + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="pooled", + antialias=True, + ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device('cpu'), + pretrained=version, + ) + del model.transformer + self.model = model + # self.mapper = torch.nn.Linear(1280, 1024) + 
self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer('std', + torch.Tensor([0.26862954, 0.26130258, + 0.27577711]), + persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli( + (1. - self.ucg_rate) * + torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + freeze=True, + layer="pooled", + antialias=True): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device('cpu'), + pretrained=version, + ) + del model.transformer + self.model = model + self.device = device + + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer('std', + torch.Tensor([0.26862954, 0.26130258, + 0.27577711]), + persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. 
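+        # Inputs are expected in [-1, 1]; the resize and shift above map them
+        # to [0, 1] at 224x224 before CLIP's channel statistics are applied.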
+ # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, image, no_dropout=False): + ## image: b c h w + z = self.encode_with_vision_transformer(image) + return z + + def encode_with_vision_transformer(self, x): + x = self.preprocess(x) + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.model.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], + self.model.visual.grid_size[0], + self.model.visual.patch_size[0], + self.model.visual.grid_size[1], + self.model.visual.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape( + x.shape[0], self.model.visual.grid_size[0] * + self.model.visual.grid_size[1], -1) + x = self.model.visual.patchnorm_pre_ln(x) + x = self.model.visual.conv1(x) + else: + x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat([ + self.model.visual.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.model.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.model.visual.patch_dropout(x) + x = self.model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.model.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +class FrozenCLIPT5Encoder(AbstractEncoder): + + def __init__(self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, + device, + max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, + device, + max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." 
+        )
+
+    def encode(self, text):
+        return self(text)
+
+    def forward(self, text):
+        clip_z = self.clip_encoder.encode(text)
+        t5_z = self.t5_encoder.encode(text)
+        return [clip_z, t5_z]
+
+
+class LinearProjector(nn.Module):
+
+    def __init__(self, input_dim: int, output_dim: int) -> None:
+        super().__init__()
+        self.projector = nn.Linear(input_dim, output_dim, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.projector(x)
+
+
+class MLPProjector(nn.Module):
+
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int,
+                 mlp_type: str = "gelu-mlp") -> None:
+        super().__init__()
+        if mlp_type == "gelu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.GELU(approximate='tanh'),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+        elif mlp_type == "silu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.SiLU(),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+        else:
+            raise ValueError(
+                f"Projector with `{mlp_type = }` is not supported!")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.projector(x)
+
+
+class PerceiverAttention(nn.Module):
+
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(
+            -2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
+
+def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
+    inner_dim = int(dim * mult)
+    if ffd_type == "gelu-ffd":
+        return nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, inner_dim, bias=False),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(inner_dim, dim, bias=False),
+        )
+    elif ffd_type == "silu-ffd":
+        return nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, inner_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(inner_dim, dim, bias=False),
+        )
+    else:
+        raise ValueError(
+            f"FeedForward with `{ffd_type = }` is not supported!")
+
+
+class SATokenProjector(nn.Module):
+
+    def __init__(self,
+                 dim=1024,
+                 depth=1,
+                 dim_head=64,
+                 heads=16,
+                 num_queries=16,
+                 output_dim=1024,
+                 ff_mult=4,
+                 chunk_size=None):
+        super().__init__()
+        self.num_queries = num_queries
+        self.chunk_size = chunk_size
+
+        if chunk_size is not None:
+            num_queries = num_queries * chunk_size
+
+        self.latents = nn.Parameter(
+            torch.randn(1, num_queries, dim) / dim**0.5)
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(dim)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList([
+                    PerceiverAttention(dim=dim, dim_head=dim_head,
+                                       heads=heads),
+                    FeedForward(dim=dim, mult=ff_mult),
+                ]))
+
+    def forward(self, x):
+        latents = self.latents.repeat(x.size(0), 1, 1)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        latents = self.proj_out(latents)
+        latents = self.norm_out(latents)
+        return latents
diff --git a/src/unifolm_wma/modules/encoders/resampler.py b/src/unifolm_wma/modules/encoders/resampler.py
new file mode 100644
index 0000000..47e7364
--- /dev/null
+++ b/src/unifolm_wma/modules/encoders/resampler.py
@@ -0,0 +1,153 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
+import math
+import torch
+import torch.nn as nn
+
+
+class ImageProjModel(nn.Module):
+    """Projection Model"""
+
+    def __init__(self,
+                 cross_attention_dim=1024,
+                 clip_embeddings_dim=1024,
+                 clip_extra_context_tokens=4):
+        super().__init__()
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = nn.Linear(
+            clip_embeddings_dim,
+            self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds):
+        embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
+        clip_extra_context_tokens = self.proj(embeds).reshape(
+            -1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # final shape stays (bs, n_heads, length, dim_per_head)
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttention(nn.Module):
+
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        # (q * s) @ (k * s).T with s = dim_head**-0.25 equals q @ k.T / sqrt(dim_head);
+        # more stable in float16 than dividing after the matmul
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(),
dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + video_length=None, # using frame-wise version or not + ): + super().__init__() + ## queries for a single frame / image + self.num_queries = num_queries + self.video_length = video_length + + ## queries for each frame + if video_length is not None: + num_queries = num_queries * video_length + + self.latents = nn.Parameter( + torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, + heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + return latents diff --git a/src/unifolm_wma/modules/networks/ae_modules.py b/src/unifolm_wma/modules/networks/ae_modules.py new file mode 100644 index 0000000..2ec124d --- /dev/null +++ b/src/unifolm_wma/modules/networks/ae_modules.py @@ -0,0 +1,1005 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import numpy as np +import torch.nn as nn + +from einops import rearrange +from unifolm_wma.modules.attention import LinearAttention +from unifolm_wma.utils.utils import instantiate_from_config + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) # bcl + q = q.permute(0, 2, 1) # bcl -> blc l=hw + k = k.reshape(b, c, h * w) # bcl + + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in 
["vanilla", "linear",
+                         "none"], f'attn_type {attn_type} unknown'
+    #print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        return AttnBlock(in_channels)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        return LinAttnBlock(in_channels)
+
+
+class Downsample(nn.Module):
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        self.in_channels = in_channels
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class Upsample(nn.Module):
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        self.in_channels = in_channels
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x,
+                                            scale_factor=2.0,
+                                            mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings. This matches the implementation in
+    Denoising Diffusion Probabilistic Models (via the Fairseq / tensor2tensor
+    code), but differs slightly from the description in Section 3.5 of
+    "Attention Is All You Need": frequency j is 10000 ** (-j / (half_dim - 1))
+    with half_dim = embedding_dim // 2, and the sin and cos halves are
+    concatenated along the channel axis.
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+class ResnetBlock(nn.Module):
+
+    def __init__(self,
+                 *,
+                 in_channels,
+                 out_channels=None,
+                 conv_shortcut=False,
+                 dropout,
+                 temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = 
self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class Model(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in 
range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], + dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2 * + z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # print(f'encoder-input={x.shape}') + # downsampling + hs = [self.conv_in(x)] + # print(f'encoder-conv in feat={hs[0].shape}') + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + # print(f'encoder-down feat={h.shape}') + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + # print(f'encoder-downsample (input)={hs[-1].shape}') + hs.append(self.down[i_level].downsample(hs[-1])) + # 
print(f'encoder-downsample (output)={hs[-1].shape}') + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + # print(f'encoder-mid1 feat={h.shape}') + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'encoder-mid2 feat={h.shape}') + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + # print(f'end feat={h.shape}') + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1, ) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("AE working on z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # print(f'decoder-input={z.shape}') + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + # print(f'decoder-conv in feat={h.shape}') + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # print(f'decoder-mid feat={h.shape}') + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # print(f'decoder-up feat={h.shape}') + if i_level != 0: + h = self.up[i_level].upsample(h) + # print(f'decoder-upsample feat={h.shape}') + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = 
nonlinearity(h) + h = self.conv_out(h) + # print(f'decoder-conv_out feat={h.shape}') + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True) + ]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2**(self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + + def __init__(self, + factor, + in_channels, + mid_channels, + out_channels, + depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=(int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * 
self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + + def __init__(self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + + def __init__(self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + + def __init__(self, + in_size, + out_size, + in_channels, + out_channels, + ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. 
+ (out_size % in_size)
+        print(
+            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+        )
+        self.rescaler = LatentRescaler(factor=factor_up,
+                                       in_channels=in_channels,
+                                       mid_channels=2 * in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels,
+                               resolution=out_size,
+                               z_channels=in_channels,
+                               num_res_blocks=2,
+                               attn_resolutions=[],
+                               in_channels=None,
+                               ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(
+                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
+            )
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor == 1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x,
+                                                mode=self.mode,
+                                                align_corners=False,
+                                                scale_factor=scale_factor)
+        return x
+
+
+class FirstStagePostProcessor(nn.Module):
+
+    def __init__(self,
+                 ch_mult: list,
+                 in_channels,
+                 pretrained_model: nn.Module = None,
+                 reshape=False,
+                 n_channels=None,
+                 dropout=0.,
+                 pretrained_config=None):
+        super().__init__()
+        if pretrained_config is None:
+            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.pretrained_model = pretrained_model
+        else:
+            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.instantiate_pretrained(pretrained_config)
+
+        self.do_reshape = reshape
+
+        if n_channels is None:
+            n_channels = self.pretrained_model.encoder.ch
+
+        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+        self.proj = nn.Conv2d(in_channels,
+                              n_channels,
+                              kernel_size=3,
+                              stride=1,
+                              padding=1)
+
+        blocks = []
+        downs = []
+        ch_in = n_channels
+        for m in ch_mult:
+            blocks.append(
+                ResnetBlock(in_channels=ch_in,
+                            out_channels=m * n_channels,
+                            dropout=dropout))
+            ch_in = m * n_channels
+            downs.append(Downsample(ch_in, with_conv=False))
+
+        self.model = nn.ModuleList(blocks)
+        self.downsampler = nn.ModuleList(downs)
+
+    def instantiate_pretrained(self, config):
+        model = instantiate_from_config(config)
+        self.pretrained_model = model.eval()
+        # self.pretrained_model.train = False
+        for param in self.pretrained_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def encode_with_pretrained(self, x):
+        c = self.pretrained_model.encode(x)
+        # NOTE: DiagonalGaussianDistribution is not imported in this module;
+        # it needs to be imported from the repo's distributions utilities for
+        # this isinstance check to resolve.
+        if isinstance(c, DiagonalGaussianDistribution):
+            c = c.mode()
+        return c
+
+    def forward(self, x):
+        z_fs = self.encode_with_pretrained(x)
+        z = self.proj_norm(z_fs)
+        z = self.proj(z)
+        z = nonlinearity(z)
+
+        for submodel, downmodel in zip(self.model, self.downsampler):
+            z = submodel(z, temb=None)
+            z = downmodel(z)
+
+        if self.do_reshape:
+            z = rearrange(z, 'b c h w -> b (h w) c')
+        return z
diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py
new file mode 100644
index 0000000..e1b4838
--- /dev/null
+++ b/src/unifolm_wma/modules/networks/wma_model.py
@@ -0,0 +1,848 @@
+import torch
+import torch.nn as nn
+import 
torch.nn.functional as F + +from torch import Tensor +from functools import partial +from abc import abstractmethod +from einops import rearrange +from omegaconf import OmegaConf +from typing import Optional, Sequence, Any, Tuple, Union, List, Dict +from collections.abc import Mapping, Iterable, Callable + +from unifolm_wma.utils.diffusion import timestep_embedding +from unifolm_wma.utils.common import checkpoint +from unifolm_wma.utils.basics import (zero_module, conv_nd, linear, + avg_pool_nd, normalization) +from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer +from unifolm_wma.utils.utils import instantiate_from_config + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, batch_size=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, batch_size=batch_size) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + elif isinstance(layer, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size) + x = layer(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + else: + x = layer(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, + self.channels, + self.out_channels, + 3, + padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. 
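+    The body is GroupNorm -> SiLU -> conv applied twice, with the timestep
+    embedding injected in between (as an additive bias, or as a scale/shift
+    pair when use_scale_shift_norm is set), plus a learned or identity skip
+    connection.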
+ :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + :param use_temporal_conv: if True, use the temporal convolution. + :param use_image_dataset: if True, the temporal parameters will not be optimized. + """ + + def __init__(self, + channels, + emb_channels, + dropout, + out_channels=None, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + use_conv=False, + up=False, + down=False, + use_temporal_conv=False, + tempspatial_aware=False): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, + channels, + self.out_channels, + 3, + padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, + 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock( + self.out_channels, + self.out_channels, + dropout=0.1, + spatial_aware=tempspatial_aware) + + def forward(self, x, emb, batch_size=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
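+        :note: When batch_size is provided, `_forward` unfolds the flattened
+            (batch * frames) axis so the temporal convolution can mix
+            information across frames before the block returns.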
+        """
+        input_tuple = (x, emb)
+        if batch_size:
+            forward_batchsize = partial(self._forward, batch_size=batch_size)
+            return checkpoint(forward_batchsize, input_tuple,
+                              self.parameters(), self.use_checkpoint)
+        return checkpoint(self._forward, input_tuple, self.parameters(),
+                          self.use_checkpoint)
+
+    def _forward(self, x, emb, batch_size=None):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        h = self.skip_connection(x) + h
+
+        if self.use_temporal_conv and batch_size:
+            h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
+            h = self.temopral_conv(h)
+            h = rearrange(h, 'b c t h w -> (b t) c h w')
+        return h
+
+
+class TemporalConvBlock(nn.Module):
+    """
+    Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels=None,
+                 dropout=0.0,
+                 spatial_aware=False):
+        super(TemporalConvBlock, self).__init__()
+        if out_channels is None:
+            out_channels = in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
+        th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
+        tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
+        tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
+
+        # conv layers
+        self.conv1 = nn.Sequential(
+            nn.GroupNorm(32, in_channels), nn.SiLU(),
+            nn.Conv3d(in_channels,
+                      out_channels,
+                      th_kernel_shape,
+                      padding=th_padding_shape))
+        self.conv2 = nn.Sequential(
+            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_channels,
+                      in_channels,
+                      tw_kernel_shape,
+                      padding=tw_padding_shape))
+        self.conv3 = nn.Sequential(
+            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_channels,
+                      in_channels,
+                      th_kernel_shape,
+                      padding=th_padding_shape))
+        self.conv4 = nn.Sequential(
+            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_channels,
+                      in_channels,
+                      tw_kernel_shape,
+                      padding=tw_padding_shape))
+
+        # Zero out the last layer's params so the conv block starts as an identity
+        nn.init.zeros_(self.conv4[-1].weight)
+        nn.init.zeros_(self.conv4[-1].bias)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+
+        return identity + x
+
+
+class WMAModel(nn.Module):
+    """
+    The full World-Model-Action model.
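+
+    A video-diffusion UNet that interleaves spatial transformer blocks (with
+    optional image cross-attention) and temporal transformer/convolution
+    blocks, and can attach action/state heads via unet_head_config; see the
+    constructor arguments below for the individual switches.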
+ """ + + def __init__(self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: Sequence[int], + dropout: float = 0.0, + channel_mult: Sequence[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + context_dim: int | None = None, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + transformer_depth: int = 1, + use_linear: bool = False, + use_checkpoint: bool = False, + temporal_conv: bool = False, + tempspatial_aware: bool = False, + temporal_attention: bool = True, + use_relative_position: bool = True, + use_causal_attention: bool = False, + temporal_length: int | None = None, + use_fp16: bool = False, + addition_attention: bool = False, + temporal_selfatt_only: bool = True, + image_cross_attention: bool = False, + cross_attention_scale_learnable: bool = False, + default_fs: int = 4, + fs_condition: bool = False, + n_obs_steps: int = 1, + num_stem_token: int = 1, + unet_head_config: OmegaConf | None = None, + stem_process_config: OmegaConf | None = None, + base_model_gen_only: bool = False): + """ + Initialize the World-Model-Action network. + + Args: + in_channels: Number of input channels to the backbone. + model_channels: Base channel width for the UNet/backbone. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks per resolution stage. + attention_resolutions: Resolutions at which to enable attention. + dropout: Dropout probability used inside residual/attention blocks. + channel_mult: Multipliers for channels at each resolution level. + conv_resample: If True, use convolutional resampling for up/down sampling. + dims: Spatial dimensionality of the backbone (1/2/3). + context_dim: Optional context embedding dimension (for cross-attention). + use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks. + resblock_updown: Use residual blocks for up/down sampling (instead of plain conv). + num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels. + num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads. + transformer_depth: Number of transformer/attention blocks per stage. + use_linear: Use linear attention variants where applicable. + use_checkpoint: Enable gradient checkpointing in blocks to save memory. + temporal_conv: Include temporal convolution along the time dimension. + tempspatial_aware: If True, use time–space aware blocks. + temporal_attention: Enable temporal self-attention. + use_relative_position: Use relative position encodings in attention. + use_causal_attention: Use causal (uni-directional) attention along time. + temporal_length: Optional maximum temporal length expected by the model. + use_fp16: Enable half-precision layers/normalization where supported. + addition_attention: Add auxiliary attention modules. + temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True. + image_cross_attention: Enable cross-attention with image embeddings. + cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter. + default_fs: Default frame-stride / fps. + fs_condition: If True, condition on frame-stride/fps features. + n_obs_steps: Number of observed steps used in conditioning heads. + num_stem_token: Number of stem tokens for action tokenization. + unet_head_config: OmegaConf for UNet heads (e.g., action/state heads). 
+            stem_process_config: OmegaConf for stem/preprocessor module.
+            base_model_gen_only: If True, run generation with the base model
+                only, without action and state outputs.
+        """
+
+        super(WMAModel, self).__init__()
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.temporal_attention = temporal_attention
+        time_embed_dim = model_channels * 4
+        self.use_checkpoint = use_checkpoint
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+        temporal_self_att_only = True
+        self.addition_attention = addition_attention
+        self.temporal_length = temporal_length
+        self.image_cross_attention = image_cross_attention
+        self.cross_attention_scale_learnable = cross_attention_scale_learnable
+        self.default_fs = default_fs
+        self.fs_condition = fs_condition
+        self.n_obs_steps = n_obs_steps
+        self.num_stem_token = num_stem_token
+        self.base_model_gen_only = base_model_gen_only
+
+        # Time embedding blocks
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+        if fs_condition:
+            self.fps_embedding = nn.Sequential(
+                linear(model_channels, time_embed_dim),
+                nn.SiLU(),
+                linear(time_embed_dim, time_embed_dim),
+            )
+            nn.init.zeros_(self.fps_embedding[-1].weight)
+            nn.init.zeros_(self.fps_embedding[-1].bias)
+        # Input Block
+        self.input_blocks = nn.ModuleList([
+            TimestepEmbedSequential(
+                conv_nd(dims, in_channels, model_channels, 3, padding=1))
+        ])
+        if self.addition_attention:
+            self.init_attn = TimestepEmbedSequential(
+                TemporalTransformer(model_channels,
+                                    n_heads=8,
+                                    d_head=num_head_channels,
+                                    depth=transformer_depth,
+                                    context_dim=context_dim,
+                                    use_checkpoint=use_checkpoint,
+                                    only_self_att=temporal_selfatt_only,
+                                    causal_attention=False,
+                                    relative_position=use_relative_position,
+                                    temporal_length=temporal_length))
+
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=mult * model_channels,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             tempspatial_aware=tempspatial_aware,
+                             use_temporal_conv=temporal_conv)
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+
+                    layers.append(
+                        SpatialTransformer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth,
+                            context_dim=context_dim,
+                            use_linear=use_linear,
+                            use_checkpoint=use_checkpoint,
+                            disable_self_attn=False,
+                            video_length=temporal_length,
+                            agent_state_context_len=self.n_obs_steps,
+                            agent_action_context_len=self.temporal_length *
+                            num_stem_token,
+                            image_cross_attention=self.image_cross_attention,
+                            cross_attention_scale_learnable=self.
+ cross_attention_scale_learnable, + )) + if self.temporal_attention: + layers.append( + TemporalTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + only_self_att=temporal_self_att_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length)) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock(ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True) + if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers = [ + ResBlock(ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv), + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + disable_self_attn=False, + video_length=temporal_length, + agent_state_context_len=self.n_obs_steps, + agent_action_context_len=self.temporal_length * num_stem_token, + image_cross_attention=self.image_cross_attention, + cross_attention_scale_learnable=self. + cross_attention_scale_learnable) + ] + if self.temporal_attention: + layers.append( + TemporalTransformer(ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + only_self_att=temporal_self_att_only, + causal_attention=use_causal_attention, + relative_position=use_relative_position, + temporal_length=temporal_length)) + layers.append( + ResBlock(ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv)) + + # Middle Block + self.middle_block = TimestepEmbedSequential(*layers) + + # Output Block + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock(ch + ich, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + tempspatial_aware=tempspatial_aware, + use_temporal_conv=temporal_conv) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear, + use_checkpoint=use_checkpoint, + disable_self_attn=False, + video_length=temporal_length, + agent_state_context_len=self.n_obs_steps, + image_cross_attention=self.image_cross_attention, + cross_attention_scale_learnable=self. 
+                            cross_attention_scale_learnable))
+                    if self.temporal_attention:
+                        layers.append(
+                            TemporalTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                use_linear=use_linear,
+                                use_checkpoint=use_checkpoint,
+                                only_self_att=temporal_self_att_only,
+                                causal_attention=use_causal_attention,
+                                relative_position=use_relative_position,
+                                temporal_length=temporal_length))
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(ch,
+                                 time_embed_dim,
+                                 dropout,
+                                 out_channels=out_ch,
+                                 dims=dims,
+                                 use_checkpoint=use_checkpoint,
+                                 use_scale_shift_norm=use_scale_shift_norm,
+                                 up=True)
+                        if resblock_updown else Upsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch))
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(
+                conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+
+        # Action and state prediction UNet heads
+        unet_head_config['params']['context_dims'] = [
+            mult * model_channels for mult in channel_mult
+        ]
+        self.action_unet = instantiate_from_config(unet_head_config)
+        self.state_unet = instantiate_from_config(unet_head_config)
+
+        # Initialize the action token projector
+        self.action_token_projector = instantiate_from_config(
+            stem_process_config)
+
+    def forward(self,
+                x: Tensor,
+                x_action: Tensor,
+                x_state: Tensor,
+                timesteps: Tensor,
+                context: Tensor | None = None,
+                context_action: Tensor | None = None,
+                features_adapter: Any = None,
+                fs: Tensor | None = None,
+                **kwargs) -> Tensor | tuple[Tensor, ...]:
+        """
+        Forward pass of the World-Model-Action backbone.
+
+        Args:
+            x: Input latent video tensor of shape (B, C, T, H, W).
+            x_action: Action stream input.
+            x_state: State stream input.
+            timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
+            context: Conditioning context for cross-attention.
+            context_action: Conditioning context specific to the action/state
+                streams (implementation-specific).
+            features_adapter: Module or dict used to adapt intermediate
+                features.
+            fs: Frame-stride / fps conditioning.
+
+        Returns:
+            A tuple of Tensors (video, action, state) predictions; when
+            `base_model_gen_only` is set, the action and state entries are
+            zero tensors.
+        """
+        b, _, t, _, _ = x.shape
+        t_emb = timestep_embedding(timesteps,
+                                   self.model_channels,
+                                   repeat_only=False).type(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        bt, l_context, _ = context.shape
+        if self.base_model_gen_only:
+            # NOTE: hand-coded context length (77 text tokens + 16 tokens per
+            # observed step).
+            assert l_context == 77 + self.n_obs_steps * 16, \
+                f'Unexpected context length {l_context} for base-model-only generation'
+        else:
+            if l_context == self.n_obs_steps + 77 + t * 16:
+                context_agent_state = context[:, :self.n_obs_steps]
+                context_text = context[:, self.n_obs_steps:self.n_obs_steps +
+                                       77, :]
+                context_img = context[:, self.n_obs_steps + 77:, :]
+                context_agent_state = context_agent_state.repeat_interleave(
+                    repeats=t, dim=0)
+                context_text = context_text.repeat_interleave(repeats=t, dim=0)
+                context_img = rearrange(context_img,
+                                        'b (t l) c -> (b t) l c',
+                                        t=t)
+                context = torch.cat(
+                    [context_agent_state, context_text, context_img], dim=1)
+            elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
+                context_agent_state = context[:, :self.n_obs_steps]
+                context_agent_action = context[:, self.
+ n_obs_steps:self.n_obs_steps + + 16, :] + context_agent_action = rearrange( + context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') + context_agent_action = self.action_token_projector( + context_agent_action) + context_agent_action = rearrange(context_agent_action, + '(b o) l d -> b o l d', + o=t) + context_agent_action = rearrange(context_agent_action, + 'b o (t l) d -> b o t l d', + t=t) + context_agent_action = context_agent_action.permute( + 0, 2, 1, 3, 4) + context_agent_action = rearrange(context_agent_action, + 'b t o l d -> (b t) (o l) d') + + context_text = context[:, self.n_obs_steps + + 16:self.n_obs_steps + 16 + 77, :] + context_text = context_text.repeat_interleave(repeats=t, dim=0) + + context_img = context[:, self.n_obs_steps + 16 + 77:, :] + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context = torch.cat([ + context_agent_state, context_agent_action, context_text, + context_img + ], + dim=1) + + emb = emb.repeat_interleave(repeats=t, dim=0) + + x = rearrange(x, 'b c t h w -> (b t) c h w') + + # Combine emb + if self.fs_condition: + if fs is None: + fs = torch.tensor([self.default_fs] * b, + dtype=torch.long, + device=x.device) + fs_emb = timestep_embedding(fs, + self.model_channels, + repeat_only=False).type(x.dtype) + + fs_embed = self.fps_embedding(fs_emb) + fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) + emb = emb + fs_embed + + h = x.type(self.dtype) + adapter_idx = 0 + hs = [] + hs_a = [] + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) + if id == 0 and self.addition_attention: + h = self.init_attn(h, emb, context=context, batch_size=b) + # plug-in adapter features + if ((id + 1) % 3 == 0) and features_adapter is not None: + h = h + features_adapter[adapter_idx] + adapter_idx += 1 + if id != 0: + if isinstance(module[0], Downsample): + hs_a.append( + rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t)) + hs.append(h) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + + if features_adapter is not None: + assert len( + features_adapter) == adapter_idx, 'Wrong features_adapter' + h = self.middle_block(h, emb, context=context, batch_size=b) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + + hs_out = [] + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + if isinstance(module[-1], Upsample): + hs_a.append( + rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) + hs_out.append(h) + h = h.type(x.dtype) + hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) + + y = self.out(h) + y = rearrange(y, '(b t) c h w -> b c t h w', b=b) + + if not self.base_model_gen_only: + ba, _, _ = x_action.shape + a_y = self.action_unet(x_action, timesteps[:ba], hs_a, + context_action[:2], **kwargs) + # Predict state + if b > 1: + s_y = self.state_unet(x_state, timesteps[:ba], hs_a, + context_action[:2], **kwargs) + else: + s_y = self.state_unet(x_state, timesteps, hs_a, + context_action[:2], **kwargs) + else: + a_y = torch.zeros_like(x_action) + s_y = torch.zeros_like(x_state) + + return y, a_y, s_y diff --git a/src/unifolm_wma/modules/vision/__init__.py b/src/unifolm_wma/modules/vision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/unifolm_wma/modules/vision/base_vision.py b/src/unifolm_wma/modules/vision/base_vision.py new file mode 100644 index 0000000..d6f09ef --- 
/dev/null +++ b/src/unifolm_wma/modules/vision/base_vision.py @@ -0,0 +1,244 @@ +""" +base_vision.py + +Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility +functions, and initialization logic. + +We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision +Transformer model for feature extraction. +""" + +import timm +import torch +import torch.nn as nn +import torchvision.transforms.functional as TVF + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from PIL.Image import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + + +# === Utility Functions for Monkey-Patching === +def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: + + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + + return wrapper + + +# === Interface for an Image Transform === +class ImageTransform(Protocol): + + def __call__( + self, img: Image, + **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + ... + + +# === Custom Torchvision Image Transforms === +@dataclass +class LetterboxPad: + padding_fill_value: Tuple[int, int, int] + + def __call__(self, image: Image) -> Image: + """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" + (w, h), max_wh = image.size, max(image.size) + horizontal_pad, vertical_pad = int((max_wh - w) / 2), int( + (max_wh - h) / 2) + padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) + return TVF.pad(image, + padding, + fill=self.padding_fill_value, + padding_mode="constant") + + +# === Abstract Base Class for arbitrary Vision Backbones === +class VisionBackbone(nn.Module, ABC): + + def __init__(self, + vision_backbone_id: str, + image_resize_strategy: str, + default_image_size: int = 224) -> None: + super().__init__() + self.identifier: str = vision_backbone_id + self.image_resize_strategy: str = image_resize_strategy + self.default_image_size: int = default_image_size + + # Instance attributes for a Vision Backbone + self.featurizer: nn.Module = None + self.image_transform: ImageTransform = None + + def get_image_transform(self) -> ImageTransform: + return self.image_transform + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: + ... + + @abstractmethod + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features.""" + raise NotImplementedError + + @property + @abstractmethod + def default_image_resolution(self) -> Tuple[int, int, int]: + ... + + @property + @abstractmethod + def embed_dim(self) -> int: + ... + + @property + @abstractmethod + def num_patches(self) -> int: + ... + + @property + @abstractmethod + def half_precision_dtype(self) -> torch.dtype: + ... 
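+
+
+# --- Illustrative sketch (not part of the original module) ----------------
+# A minimal concrete `VisionBackbone`, shown only to clarify the abstract
+# interface above; the identity featurizer and the 768-dim / 196-patch
+# numbers are hypothetical placeholder values, not anything used in this
+# repository.
+class _ExampleIdentityBackbone(VisionBackbone):
+
+    def __init__(self) -> None:
+        super().__init__("example-identity",
+                         "resize-naive",
+                         default_image_size=224)
+        self.featurizer = nn.Identity()
+        self.image_transform = Compose(
+            [Resize((self.default_image_size, self.default_image_size))])
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        return partial(_module_wrap_policy, module_classes={nn.Identity})
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.featurizer(pixel_values)
+
+    @property
+    def default_image_resolution(self) -> Tuple[int, int, int]:
+        return (3, self.default_image_size, self.default_image_size)
+
+    @property
+    def embed_dim(self) -> int:
+        return 768
+
+    @property
+    def num_patches(self) -> int:
+        return 196
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16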
+ + +# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === +class TimmViTBackbone(VisionBackbone, ABC): + + def __init__( + self, + vision_backbone_id: str, + timm_path_or_url: str, + image_resize_strategy: str, + default_image_size: int = 224, + override_act_layer: Optional[str] = None, + ) -> None: + super().__init__(vision_backbone_id, + image_resize_strategy, + default_image_size=default_image_size) + self.timm_path_or_url = timm_path_or_url + self.override_act_layer = override_act_layer + self.dtype = torch.bfloat16 + + # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary + if self.override_act_layer is None: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size) + else: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size, + act_layer=self.override_act_layer, + ) + self.featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.featurizer.forward = unpack_tuple( + partial(self.featurizer.get_intermediate_layers, + n={len(self.featurizer.blocks) - 2})) + + # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!) + assert isinstance(self.featurizer, VisionTransformer), ( + "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, " + "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!" + ) + + # Get Config =>> Note :: Override default image size to ensure correct image transform + self.data_cfg = timm.data.resolve_model_data_config(self.featurizer) + self.data_cfg["input_size"] = (3, self.default_image_size, + self.default_image_size) + + # Initialize Default Image Transform --> Modified by `self.image_resize_strategy` + default_image_transform = timm.data.create_transform(**self.data_cfg, + is_training=False) + + # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)! + if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url: + assert isinstance(default_image_transform, + Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_image_transform.transforms[0], Resize) + default_image_transform = Compose([ + Resize(self.default_image_size, + interpolation=default_image_transform.transforms[0]. + interpolation), + *default_image_transform.transforms[1:], + ]) + + # Switch on `image_resize_strategy` + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_image_transform, + Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_image_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + self.image_transform = Compose([ + Resize(target_size, + interpolation=default_image_transform.transforms[0]. 
+ interpolation), + *default_image_transform.transforms[1:], + ]) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = default_image_transform + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_image_transform, + Compose), "Unexpected `default_image_transform`!" + assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!" + + # Compute Padding Fill Value (rescaled normalization mean if applicable) + fill = tuple([int(x * 255) for x in self.data_cfg["mean"]]) + + # Build New Transform + self.image_transform = Compose( + [LetterboxPad(fill), *default_image_transform.transforms]) + + else: + raise ValueError( + f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!" + ) + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer.""" + vit_wrap_policy = partial(_module_wrap_policy, + module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, + transformer_layer_cls={Block}) + return partial(_or_policy, + policies=[vit_wrap_policy, transformer_block_policy]) + + def forward( + self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]] + ) -> torch.Tensor: + """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features.""" + return self.featurizer(pixel_values) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.featurizer.embed_dim + + @property + def num_patches(self) -> int: + return self.featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return self.dtype diff --git a/src/unifolm_wma/modules/vision/dinosiglip_vit.py b/src/unifolm_wma/modules/vision/dinosiglip_vit.py new file mode 100644 index 0000000..c688678 --- /dev/null +++ b/src/unifolm_wma/modules/vision/dinosiglip_vit.py @@ -0,0 +1,273 @@ +""" +dinosiglip_vit.py + +Vision backbone that returns concatenated features from both DINOv2 and SigLIP. 
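+
+Example (a sketch, assuming the registry below and access to the TIMM Hub;
+`img_batch` is a hypothetical (B, 3, H, W) tensor scaled to roughly [-1, 1]):
+
+    backbone = DinoSigLIPViTBackbone("dinosiglip-vit-so-224px",
+                                     "resize-naive",
+                                     arch_specifier="gelu-mlp",
+                                     output_dim=1024)
+    feats = backbone(img_batch)  # roughly (B, num_patches, output_dim)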
+""" + +import timm +import torch +import torchvision.transforms as transforms + +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Tuple +from PIL import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize, Normalize + +from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple +from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector + +# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) +DINOSigLIP_VISION_BACKBONES = { + "dinosiglip-vit-so-224px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_224", + }, + "dinosiglip-vit-so-384px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_384", + }, +} + + +@dataclass +class DinoSigLIPImageTransform: + dino_image_transform: ImageTransform + siglip_image_transform: ImageTransform + is_prismatic: bool = True + + def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: + return { + "dino": self.dino_image_transform(img, **kwargs), + "siglip": self.siglip_image_transform(img, **kwargs) + } + + +class DinoSigLIPViTBackbone(VisionBackbone): + + def __init__(self, + vision_backbone_id: str, + image_resize_strategy: str, + arch_specifier: str, + output_dim: int, + pretrained_checkpoint=None, + freeze=True, + default_image_size: int = 224) -> None: + super().__init__(vision_backbone_id, + image_resize_strategy, + default_image_size=default_image_size) + self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[ + vision_backbone_id]["dino"] + self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[ + vision_backbone_id]["siglip"] + + # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary + self.dino_featurizer: VisionTransformer = timm.create_model( + self.dino_timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size) + if pretrained_checkpoint: + ckpt = pretrained_checkpoint + '/openvla_dino.pt' + self.dino_featurizer.load_state_dict( + torch.load(ckpt, weights_only=True)) + print('>>> load dino weights') + if freeze: + self.dino_featurizer.eval() + for param in self.dino_featurizer.parameters(): + param.requires_grad = False + + self.siglip_featurizer: VisionTransformer = timm.create_model( + self.siglip_timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size) + if pretrained_checkpoint: + ckpt = pretrained_checkpoint + '/openvla_siglip.pt' + self.siglip_featurizer.load_state_dict( + torch.load(ckpt, weights_only=True)) + print('>>> load siglip weights') + if freeze: + self.siglip_featurizer.eval() + for param in self.siglip_featurizer.parameters(): + param.requires_grad = False + + # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.dino_featurizer.forward = unpack_tuple( + partial(self.dino_featurizer.get_intermediate_layers, + n={len(self.dino_featurizer.blocks) - 2})) + self.siglip_featurizer.forward = unpack_tuple( + partial(self.siglip_featurizer.get_intermediate_layers, + n={len(self.siglip_featurizer.blocks) - 2})) + + # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models + self.dino_data_cfg = timm.data.resolve_model_data_config( + self.dino_featurizer) + self.dino_data_cfg["input_size"] = (3, self.default_image_size, + self.default_image_size) + + self.siglip_data_cfg = timm.data.resolve_model_data_config( + self.siglip_featurizer) + self.siglip_data_cfg["input_size"] = (3, self.default_image_size, + self.default_image_size) + + # Initialize *both* Transforms + self.default_dino_transform = timm.data.create_transform( + **self.dino_data_cfg, is_training=False) + self.default_siglip_transform = timm.data.create_transform( + **self.siglip_data_cfg, is_training=False) + + # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! + assert isinstance(self.default_siglip_transform, + Compose), "Unexpected `default_image_transform`!" + assert isinstance(self.default_siglip_transform.transforms[0], Resize) + self.default_siglip_transform = Compose([ + Resize(self.default_image_size, + interpolation=self.default_siglip_transform.transforms[0]. + interpolation), + *self.default_siglip_transform.transforms[1:], + ]) + + if self.image_resize_strategy == "resize-naive": + assert isinstance( + self.default_dino_transform, + Compose), "Unexpected `default_dino_image_transform`!" + assert isinstance( + self.default_siglip_transform, + Compose), "Unexpected `default_siglip_image_transform`!" + assert isinstance(self.default_dino_transform.transforms[0], + Resize) + assert isinstance(self.default_siglip_transform.transforms[0], + Resize) + + self.target_size = (self.default_image_size, + self.default_image_size) + dino_transform = Compose([ + Resize(self.target_size, + interpolation=self.default_dino_transform.transforms[0]. + interpolation), + *self.default_dino_transform.transforms[1:], + ]) + siglip_transform = Compose([ + Resize(self.target_size, + interpolation=self.default_siglip_transform. + transforms[0].interpolation), + *self.default_siglip_transform.transforms[1:], + ]) + + self.image_transform = DinoSigLIPImageTransform( + dino_transform, siglip_transform) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = DinoSigLIPImageTransform( + self.default_dino_transform, self.default_siglip_transform) + + elif self.image_resize_strategy == "letterbox": + assert isinstance(self.default_dino_transform, + Compose), "Unexpected `default_dino_transform`!" + assert isinstance( + self.default_siglip_transform, + Compose), "Unexpected `default_siglip_transform`!" + assert ("mean" in self.dino_data_cfg + and "mean" in self.siglip_data_cfg + ), "DinoSigLIP `data_cfg` missing `mean`!" 
+ + # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) + dino_fill = tuple( + [int(x * 255) for x in self.dino_data_cfg["mean"]]) + siglip_fill = tuple( + [int(x * 255) for x in self.siglip_data_cfg["mean"]]) + + # Build New Transform + self.image_transform = DinoSigLIPImageTransform( + Compose([ + LetterboxPad(dino_fill), + *self.default_dino_transform.transforms + ]), + Compose([ + LetterboxPad(siglip_fill), + *self.default_siglip_transform.transforms + ]), + ) + + else: + raise ValueError( + f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!" + ) + + self.arch_specifier = arch_specifier + if arch_specifier == "linear": + self.projector = LinearProjector(self.embed_dim, output_dim) + elif arch_specifier.endswith("fused-gelu-mlp"): + self.projector = FusedMLPProjector(self.embed_dim, output_dim) + elif arch_specifier.endswith("gelu-mlp"): + self.projector = MLPProjector(self.embed_dim, output_dim) + else: + raise ValueError( + f"PrismaticVLM with `{arch_specifier = }` is not supported!") + + self.on_gpu = False + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" + vit_wrap_policy = partial(_module_wrap_policy, + module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, + transformer_layer_cls={Block}) + return partial(_or_policy, + policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, img) -> torch.Tensor: + img = torch.clamp(img.float(), -1., 1.) + img = (img + 1.0) / 2.0 + img = img * 255 + + resize = transforms.Resize(min(self.target_size), + interpolation=self.default_dino_transform. + transforms[0].interpolation, + max_size=None, + antialias=True) + center_crop = transforms.CenterCrop(self.target_size) + img = center_crop(resize(img)) + + dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560, + 0.4060]), + std=torch.tensor([0.2290, 0.2240, 0.2250])) + siglip_normalizer = Normalize( + mean=torch.tensor([0.5000, 0.5000, 0.5000]), + std=torch.tensor([0.5000, 0.5000, 0.5000])) + pixel_values = { + 'dino': dino_normalizer(img), + 'siglip': siglip_normalizer(img) + } + + if self.on_gpu: + pixel_values = {k: v.cuda() for k, v in pixel_values.items()} + elif next(self.dino_featurizer.parameters()).device.type != 'cpu': + self.on_gpu = True + """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" + dino_patches = self.dino_featurizer(pixel_values["dino"]) + siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) + + return self.projector(torch.cat([dino_patches, siglip_patches], dim=2)) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.dino_data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim + + @property + def num_patches(self) -> int: + assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches + return self.dino_featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/src/unifolm_wma/utils/basics.py b/src/unifolm_wma/utils/basics.py new file mode 100644 index 0000000..088298b --- /dev/null +++ b/src/unifolm_wma/utils/basics.py @@ -0,0 +1,104 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# 
and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import torch.nn as nn +from unifolm_wma.utils.utils import instantiate_from_config + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def nonlinearity(type='silu'): + if type == 'silu': + return nn.SiLU() + elif type == 'leaky_relu': + return nn.LeakyReLU() + + +class GroupNormSpecific(nn.GroupNorm): + + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels, num_groups=32): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + return GroupNormSpecific(num_groups, channels) + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config( + c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} diff --git a/src/unifolm_wma/utils/callbacks.py b/src/unifolm_wma/utils/callbacks.py new file mode 100644 index 0000000..0458394 --- /dev/null +++ b/src/unifolm_wma/utils/callbacks.py @@ -0,0 +1,226 @@ +import os +import time +import logging +import json + +mainlogger = logging.getLogger('mainlogger') + +import torch +import torchvision +import pytorch_lightning as pl +import matplotlib.pyplot as plt + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_info + +from unifolm_wma.utils.save_video import log_local, prepare_to_log + +STAT_DIR = '~/' + + +class ImageLogger(Callback): + + def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \ + to_local=False, log_images_kwargs=None): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.to_local = to_local + self.clamp = clamp + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.save_stat_dir = os.path.join(save_dir, "stat") + os.makedirs(self.save_stat_dir, exist_ok=True) + self.fps_stat = {} + self.fs_stat = {} + if self.to_local: + self.save_dir = os.path.join(save_dir, "images") + os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True) + os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True) + self.count_data = 0 + + def log_to_tensorboard(self, + pl_module, + batch_logs, + filename, + split, + save_fps=8): + """ log images and videos to tensorboard """ + global_step = pl_module.global_step + for key in batch_logs: + value = batch_logs[key] + tag = "gs%d-%s/%s||%s||%s||%s" % ( + global_step, split, key, + batch_logs['condition'][0].split('_')[0], + batch_logs['condition'][0].split('_')[1], + batch_logs['video_idx']) + if isinstance(value, list) and isinstance(value[0], str): + captions = ' |------| '.join(value) + pl_module.logger.experiment.add_text(tag, + captions, + global_step=global_step) + elif isinstance(value, torch.Tensor) and value.dim() == 5: + video = value + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, + nrow=int(n), + padding=0) + for framesheet in video + ] + grid = torch.stack( + frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + pl_module.logger.experiment.add_video(tag, + grid, + fps=save_fps, + global_step=global_step) + elif isinstance(value, torch.Tensor) and value.dim() == 4: + img = value + grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0) + grid = (grid + 1.0) / 2.0 + pl_module.logger.experiment.add_image(tag, + grid, + global_step=global_step) + elif isinstance(value, torch.Tensor) and value.dim() == 3: + b, _, _ = value.shape + value1 = value[:b // 2, ...] + value2 = value[b // 2:, ...] 
+ _, num_points, d = value1.shape + for i in range(d): + data1 = value1[0, :, i].cpu().detach().numpy() + data2 = value2[0, :, i].cpu().detach().numpy() + fig, ax = plt.subplots() + ax.plot(data1, label='Target 1') + ax.plot(data2, label='Sample 1') + ax.set_title(f'Comparison at dimension {i} for {key}') + ax.legend() + pl_module.logger.experiment.add_figure( + tag + f"| {key}_dim_{i}", fig, global_step=global_step) + plt.close(fig) + else: + pass + + @rank_zero_only + def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"): + """ generate images, then save and log to tensorboard """ + # Update fps and fs statistics + batch_fps = batch['fps'].tolist() + batch_fs = batch['frame_stride'].tolist() + for num in batch_fps: + self.fps_stat[num] = self.fps_stat.get(num, 0) + 1 + for num in batch_fs: + self.fs_stat[num] = self.fs_stat.get(num, 0) + 1 + skip_freq = self.batch_freq if split == "train" else 5 + ## NOTE HAND CODE + self.count_data += 12.5 * 2 + if self.count_data >= skip_freq: + self.count_data = 0 + + is_train = pl_module.training + if is_train: + pl_module.eval() + torch.cuda.empty_cache() + with torch.no_grad(): + log_func = pl_module.log_images + batch_logs = log_func(batch, + split=split, + **self.log_images_kwargs) + # Log fps and fs statistics + with open(self.save_stat_dir + '/fps_fs_stat.json', + 'w') as file: + json.dump({ + 'fps': self.fps_stat, + 'fs': self.fs_stat + }, + file, + indent=4) + + batch_logs = prepare_to_log(batch_logs, self.max_images, + self.clamp) + torch.cuda.empty_cache() + + filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch, + batch_idx, + pl_module.global_rank) + if self.to_local: + mainlogger.info("Log [%s] batch <%s> to local ..." % + (split, filename)) + filename = "gs{}_".format(pl_module.global_step) + filename + log_local(batch_logs, + os.path.join(self.save_dir, split), + filename, + save_fps=10) + else: + mainlogger.info("Log [%s] batch <%s> to tensorboard ..." 
                                % (split, filename))
+                self.log_to_tensorboard(pl_module,
+                                        batch_logs,
+                                        filename,
+                                        split,
+                                        save_fps=10)
+            mainlogger.info('Finish!')
+
+            if is_train:
+                pl_module.train()
+
+    def on_train_batch_end(self,
+                           trainer,
+                           pl_module,
+                           outputs,
+                           batch,
+                           batch_idx,
+                           dataloader_idx=None):
+        if self.batch_freq != -1 and pl_module.logdir:
+            self.log_batch_imgs(pl_module, batch, batch_idx, split="train")
+
+    def on_validation_batch_end(self,
+                                trainer,
+                                pl_module,
+                                outputs,
+                                batch,
+                                batch_idx,
+                                dataloader_idx=None):
+        # Unlike validation_step(), which saves the whole validation set and
+        # keeps only the latest results, this records every validation run
+        # (nothing is overwritten) by keeping just a subset.
+        if self.batch_freq != -1 and pl_module.logdir:
+            self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
+        if hasattr(pl_module, 'calibrate_grad_norm'):
+            if (pl_module.calibrate_grad_norm
+                    and batch_idx % 25 == 0) and batch_idx > 0:
+                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+
+
+class CUDACallback(Callback):
+    # See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
+    def on_train_epoch_start(self, trainer, pl_module):
+        # Reset the peak memory counter; Lightning >= 1.7 moved the root
+        # device index onto the strategy object.
+        if int((pl.__version__).split('.')[1]) >= 7:
+            gpu_index = trainer.strategy.root_device.index
+        else:
+            gpu_index = trainer.root_gpu
+        torch.cuda.reset_peak_memory_stats(gpu_index)
+        torch.cuda.synchronize(gpu_index)
+        self.start_time = time.time()
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        if int((pl.__version__).split('.')[1]) >= 7:
+            gpu_index = trainer.strategy.root_device.index
+        else:
+            gpu_index = trainer.root_gpu
+        torch.cuda.synchronize(gpu_index)
+        max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
+        epoch_time = time.time() - self.start_time
+
+        try:
+            max_memory = trainer.training_type_plugin.reduce(max_memory)
+            epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+
+            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
+            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
+        except AttributeError:
+            pass
diff --git a/src/unifolm_wma/utils/common.py b/src/unifolm_wma/utils/common.py
new file mode 100644
index 0000000..f5b8a03
--- /dev/null
+++ b/src/unifolm_wma/utils/common.py
@@ -0,0 +1,111 @@
+import math
+from inspect import isfunction
+import torch
+from torch import Tensor, nn
+import torch.distributed as dist
+
+
+def gather_data(data, return_np=True):
+    '''Gather data from multiple processes into one list.'''
+    data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
+    dist.all_gather(data_list, data)  # gather not supported with NCCL
+    if return_np:
+        data_list = [data.cpu().numpy() for data in data_list]
+    return data_list
+
+
+def autocast(f):
+
+    def do_autocast(*args, **kwargs):
+        with torch.cuda.amp.autocast(
+                enabled=True,
+                dtype=torch.get_autocast_gpu_dtype(),
+                cache_enabled=torch.is_autocast_cache_enabled()):
+            return f(*args, **kwargs)
+
+    return do_autocast
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+        shape[0], *((1, ) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def exists(val):
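+    """Return True if `val` is not None."""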
+ return val is not None + + +def identity(*args, **kwargs): + return nn.Identity() + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def shape_to_str(x): + shape_str = "x".join([str(x) for x in x.shape]) + return shape_str + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +ckpt = torch.utils.checkpoint.checkpoint + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + return ckpt(func, *inputs, use_reentrant=False) + else: + return func(*inputs) diff --git a/src/unifolm_wma/utils/data.py b/src/unifolm_wma/utils/data.py new file mode 100644 index 0000000..589da5c --- /dev/null +++ b/src/unifolm_wma/utils/data.py @@ -0,0 +1,242 @@ +import os, sys +import numpy as np +import torch +import pytorch_lightning as pl + +from functools import partial +from torch.utils.data import (DataLoader, Dataset, ConcatDataset, + WeightedRandomSampler) +from unifolm_wma.data.base import Txt2ImgIterableBaseDataset +from unifolm_wma.utils.utils import instantiate_from_config + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, Txt2ImgIterableBaseDataset): + split_size = dataset.num_records // worker_info.num_workers + # Reset num_records to the true number to retain reliable length information + dataset.sample_ids = dataset.valid_ids[worker_id * + split_size:(worker_id + 1) * + split_size] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + else: + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class DataModuleFromConfig(pl.LightningDataModule): + + def __init__(self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=True, + train_img=None, + dataset_and_weights=None): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = dict() + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if 
validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, + shuffle=shuffle_val_dataloader) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, + shuffle=shuffle_test_loader) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + + self.img_loader = None + self.wrap = wrap + self.collate_fn = None + self.dataset_weights = dataset_and_weights + assert round(sum(self.dataset_weights.values()), + 2) == 1.0, "The sum of dataset weights != 1.0" + + def prepare_data(self): + pass + + def setup(self, stage=None): + if 'train' in self.dataset_configs: + self.train_datasets = dict() + for dataname in self.dataset_weights: + data_dir = self.dataset_configs['train']['params']['data_dir'] + transition_dir = '/'.join([data_dir, 'transitions']) + csv_file = f'{dataname}.csv' + meta_path = '/'.join([data_dir, csv_file]) + self.dataset_configs['train']['params'][ + 'meta_path'] = meta_path + self.dataset_configs['train']['params'][ + 'transition_dir'] = transition_dir + self.dataset_configs['train']['params'][ + 'dataset_name'] = dataname + self.train_datasets[dataname] = instantiate_from_config( + self.dataset_configs['train']) + + # Setup validation dataset + if 'validation' in self.dataset_configs: + self.val_datasets = dict() + for dataname in self.dataset_weights: + data_dir = self.dataset_configs['validation']['params'][ + 'data_dir'] + transition_dir = '/'.join([data_dir, 'transitions']) + csv_file = f'{dataname}.csv' + meta_path = '/'.join([data_dir, csv_file]) + self.dataset_configs['validation']['params'][ + 'meta_path'] = meta_path + self.dataset_configs['validation']['params'][ + 'transition_dir'] = transition_dir + self.dataset_configs['validation']['params'][ + 'dataset_name'] = dataname + self.val_datasets[dataname] = instantiate_from_config( + self.dataset_configs['validation']) + + # Setup test dataset + if 'test' in self.dataset_configs: + self.test_datasets = dict() + for dataname in self.dataset_weights: + data_dir = self.dataset_configs['test']['params']['data_dir'] + transition_dir = '/'.join([data_dir, 'transitions']) + csv_file = f'{dataname}.csv' + meta_path = '/'.join([data_dir, csv_file]) + self.dataset_configs['test']['params']['meta_path'] = meta_path + self.dataset_configs['test']['params'][ + 'transition_dir'] = transition_dir + self.dataset_configs['test']['params'][ + 'dataset_name'] = dataname + self.test_datasets[dataname] = instantiate_from_config( + self.dataset_configs['test']) + + if self.wrap: + for k in self.datasets: + self.datasets[k] = WrappedDataset(self.datasets[k]) + + def _train_dataloader(self): + is_iterable_dataset = False # NOTE Hand Code + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + combined_dataset = [] + sample_weights = [] + for dataname, dataset in self.train_datasets.items(): + combined_dataset.append(dataset) + sample_weights.append( + torch.full((len(dataset), ), + self.dataset_weights[dataname] / len(dataset))) + combined_dataset = ConcatDataset(combined_dataset) + sample_weights = torch.cat(sample_weights) + sampler = WeightedRandomSampler(sample_weights, + num_samples=len(combined_dataset), + replacement=True) + loader = DataLoader(combined_dataset, + sampler=sampler, + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + 
collate_fn=self.collate_fn, + drop_last=True + ) + return loader + + def _val_dataloader(self, shuffle=False): + is_iterable_dataset = False # NOTE Hand Code + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + combined_dataset = [] + sample_weights = [] + for dataname, dataset in self.val_datasets.items(): + combined_dataset.append(dataset) + sample_weights.append( + torch.full((len(dataset), ), + self.dataset_weights[dataname] / len(dataset))) + combined_dataset = ConcatDataset(combined_dataset) + sample_weights = torch.cat(sample_weights) + sampler = WeightedRandomSampler(sample_weights, + num_samples=len(combined_dataset), + replacement=True) + loader = DataLoader(combined_dataset, + sampler=sampler, + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + collate_fn=self.collate_fn) + return loader + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = False # NOTE Hand Code + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + combined_dataset = [] + sample_weights = [] + for dataname, dataset in self.test_datasets.items(): + combined_dataset.append(dataset) + sample_weights.append( + torch.full((len(dataset), ), + self.dataset_weights[dataname] / len(dataset))) + combined_dataset = ConcatDataset(combined_dataset) + sample_weights = torch.cat(sample_weights) + sampler = WeightedRandomSampler(sample_weights, + num_samples=len(combined_dataset), + replacement=True) + loader = DataLoader(combined_dataset, + sampler=sampler, + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + collate_fn=self.collate_fn) + return loader + + def _predict_dataloader(self, shuffle=False): + if isinstance(self.datasets['predict'], + Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader( + self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + collate_fn=self.collate_fn, + ) + + def __len__(self): + count = 0 + for _, values in self.train_datasets.items(): + count += len(values) + return count diff --git a/src/unifolm_wma/utils/diffusion.py b/src/unifolm_wma/utils/diffusion.py new file mode 100644 index 0000000..ee92eab --- /dev/null +++ b/src/unifolm_wma/utils/diffusion.py @@ -0,0 +1,191 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F +from einops import repeat + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
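+
+    Example (shape sketch only; values depend on `max_period`):
+        >>> t = torch.tensor([0, 250, 999])
+        >>> timestep_embedding(t, dim=320).shape
+        torch.Size([3, 320])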
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def make_beta_schedule(schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if schedule == "linear": + betas = (torch.linspace(linear_start**0.5, + linear_end**0.5, + n_timestep, + dtype=torch.float64)**2) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + + cosine_s) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, + linear_end, + n_timestep, + dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, + linear_end, + n_timestep, + dtype=torch.float64)**0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, + num_ddim_timesteps, + num_ddpm_timesteps, + verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + steps_out = ddim_timesteps + 1 + elif ddim_discr_method == 'uniform_trailing': + c = num_ddpm_timesteps / num_ddim_timesteps + ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, + -c))).astype(np.int64) + steps_out = ddim_timesteps - 1 + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), + num_ddim_timesteps))**2).astype(int) + steps_out = ddim_timesteps + 1 + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + # steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, + ddim_timesteps, + eta, + verbose=True): + # select alphas for computing the variance schedule + # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print( + f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}' + ) + print( + f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}' + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. 
+ :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + Args: + betas (`numpy.ndarray`): + the betas that the scheduler is being initialized with. + + Returns: + `numpy.ndarray`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_bar_sqrt = np.sqrt(alphas_cumprod) + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - + alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = np.concatenate([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), + keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # Rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # Mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + ( + 1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/src/unifolm_wma/utils/distributions.py b/src/unifolm_wma/utils/distributions.py new file mode 100644 index 0000000..8b04a67 --- /dev/null +++ b/src/unifolm_wma/utils/distributions.py @@ -0,0 +1,94 @@ +import torch +import numpy as np + + +class AbstractDistribution: + + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self, noise=None): + if noise is None: + noise = torch.randn(self.mean.shape) + + x = self.mean + self.std * noise.to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2)**2) * torch.exp(-logvar2)) diff --git a/src/unifolm_wma/utils/ema.py b/src/unifolm_wma/utils/ema.py new file mode 100644 index 0000000..d3898a8 --- /dev/null +++ b/src/unifolm_wma/utils/ema.py @@ -0,0 +1,84 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + 'num_updates', + torch.tensor(0, dtype=torch.int) + if use_num_upates else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #Remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, + (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as( + m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * + (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_( + shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/unifolm_wma/utils/nn_utils.py b/src/unifolm_wma/utils/nn_utils.py new file mode 100644 index 0000000..15dabd3 --- /dev/null +++ b/src/unifolm_wma/utils/nn_utils.py @@ -0,0 +1,66 @@ +""" +nn_utils.py + +Utility functions and PyTorch submodule definitions. 
+""" + +import torch +import torch.nn as nn + + +# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === +class LinearProjector(nn.Module): + + def __init__(self, vision_dim: int, llm_dim: int) -> None: + super().__init__() + self.projector = nn.Linear(vision_dim, llm_dim, bias=True) + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class MLPProjector(nn.Module): + + def __init__(self, + vision_dim: int, + llm_dim: int, + mlp_type: str = "gelu-mlp") -> None: + super().__init__() + if mlp_type == "gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(vision_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError( + f"Projector with `{mlp_type = }` is not supported!") + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class FusedMLPProjector(nn.Module): + + def __init__(self, + fused_vision_dim: int, + llm_dim: int, + mlp_type: str = "fused-gelu-mlp") -> None: + super().__init__() + self.initial_projection_dim = fused_vision_dim * 4 + if mlp_type == "fused-gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(fused_vision_dim, + self.initial_projection_dim, + bias=True), + nn.GELU(), + nn.Linear(self.initial_projection_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError( + f"Fused Projector with `{mlp_type = }` is not supported!") + + def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(fused_img_patches) diff --git a/src/unifolm_wma/utils/projector.py b/src/unifolm_wma/utils/projector.py new file mode 100644 index 0000000..579b7c2 --- /dev/null +++ b/src/unifolm_wma/utils/projector.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn + + +class LinearProjector(nn.Module): + def __init__(self, input_dim: int, output_dim: int) -> None: + super().__init__() + self.projector = nn.Linear(input_dim, output_dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.projector(x) + + +class MLPProjector(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + mlp_type: str = "gelu-mlp") -> None: + super().__init__() + if mlp_type == "gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(vision_dim, llm_dim, bias=True), + nn.GELU(approximate='tanh'), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + elif mlp_type == "silu-mlp": + self.projector = nn.Sequential( + nn.Linear(vision_dim, llm_dim, bias=True), + nn.SiLU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError(f"Projector with `{mlp_type = }` is not supported!") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.projector(x) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = 
+
+
+def reshape_tensor(x, heads):
+    # (b, n, heads * dim_head) -> (b, heads, n, dim_head)
+    bs, length, _ = x.shape
+    x = x.view(bs, length, heads, -1)
+    x = x.transpose(1, 2)
+    return x.reshape(bs, heads, length, -1)
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features of shape (b, n1, D).
+            latents (torch.Tensor): latent features of shape (b, n2, D).
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
+
+def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
+    inner_dim = int(dim * mult)
+    if ffd_type == "gelu-ffd":
+        return nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, inner_dim, bias=False),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(inner_dim, dim, bias=False),
+        )
+    elif ffd_type == "silu-ffd":
+        return nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, inner_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(inner_dim, dim, bias=False),
+        )
+    else:
+        raise ValueError(f"FeedForward with `{ffd_type = }` is not supported!")
+
+
+class TokenProjector(nn.Module):
+    def __init__(
+        self,
+        dim=1024,
+        depth=1,
+        dim_head=64,
+        heads=16,
+        num_queries=16,
+        output_dim=1024,
+        ff_mult=4,
+        chunck_size=None,
+    ):
+        super().__init__()
+        self.num_queries = num_queries
+        self.chunck_size = chunck_size
+        if chunck_size is not None:
+            num_queries = num_queries * chunck_size
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(dim)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+    def forward(self, x):
+        latents = self.latents.repeat(x.size(0), 1, 1)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
diff --git a/src/unifolm_wma/utils/save_video.py b/src/unifolm_wma/utils/save_video.py
new file mode 100644
index 0000000..83f475b
--- /dev/null
+++ b/src/unifolm_wma/utils/save_video.py
@@ -0,0 +1,258 @@
+import os
+import torch
+import numpy as np
+import torchvision
+
+from tqdm import tqdm
+from PIL import Image
+from einops import rearrange
+from torch import Tensor
+from torchvision.utils import make_grid
+from torchvision.transforms.functional import to_tensor
+
+
+def frames_to_mp4(frame_dir, output_path, fps):
+
+    def read_first_n_frames(d: os.PathLike, num_frames: int):
+        if num_frames:
+            images = [
+                Image.open(os.path.join(d, f))
+                for f in sorted(os.listdir(d))[:num_frames]
+            ]
+        else:
+            images = [
+                Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))
+            ]
+        images = [to_tensor(x) for x in images]
+        return torch.stack(images)
+
+    videos = read_first_n_frames(frame_dir, num_frames=None)
+    videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
+    torchvision.io.write_video(output_path,
+                               videos,
+                               fps=fps,
+                               video_codec='h264',
+                               options={'crf': '10'})
+
+
+def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
+    """
+    video: torch.Tensor, b,c,t,h,w, 0-1
+    if -1~1, enable rescale=True
+    """
+    n = video.shape[0]
+    video = video.permute(2, 0, 1, 3, 4)  # (t, b, c, h, w)
+    nrow = int(np.sqrt(n)) if nrow is None else nrow
+    frame_grids = [
+        torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0)
+        for framesheet in video
+    ]
+    grid = torch.stack(frame_grids, dim=0)
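+    # grid: (t, c, H, W); each frame is a spatial grid of the whole batch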
+ grid = torch.clamp(grid.float(), -1., 1.) + if rescale: + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute( + 0, 2, 3, 1) + torchvision.io.write_video(savepath, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True): + assert (video.dim() == 5) + assert (isinstance(video, torch.Tensor)) + + video = video.detach().cpu() + if clamp: + video = torch.clamp(video, -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n))) + for framesheet in video + ] + grid = torch.stack(frame_grids, + dim=0) + if rescale: + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute( + 0, 2, 3, 1) + path = os.path.join(root, filename) + torchvision.io.write_video(path, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True): + if batch_logs is None: + return None + """ save images and videos from images dict """ + + def save_img_grid(grid, path, rescale): + if rescale: + grid = (grid + 1.0) / 2.0 + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + for key in batch_logs: + value = batch_logs[key] + if isinstance(value, list) and isinstance(value[0], str): + # A batch of captions + path = os.path.join(save_dir, "%s-%s.txt" % (key, filename)) + with open(path, 'w') as f: + for i, txt in enumerate(value): + f.write(f'idx={i}, txt={txt}\n') + f.close() + elif isinstance(value, torch.Tensor) and value.dim() == 5: + # Save video grids + video = value + # Only save grayscale or rgb mode + if video.shape[1] != 1 and video.shape[1] != 3: + continue + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, + dim=0) + if rescale: + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename)) + torchvision.io.write_video(path, + grid, + fps=save_fps, + video_codec='h264', + options={'crf': '10'}) + + # Save frame sheet + img = value + video_frames = rearrange(img, 'b c t h w -> (b t) c h w') + t = img.shape[2] + grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0) + path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) + # Save_img_grid(grid, path, rescale) + elif isinstance(value, torch.Tensor) and value.dim() == 4: + # Save image grids + img = value + # Only save grayscale or rgb mode + if img.shape[1] != 1 and img.shape[1] != 3: + continue + n = img.shape[0] + grid = torchvision.utils.make_grid(img, nrow=1, padding=0) + path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) + save_img_grid(grid, path, rescale) + else: + pass + + +def prepare_to_log(batch_logs, max_images=100000, clamp=True): + if batch_logs is None: + return None + for key in batch_logs: + N = batch_logs[key].shape[0] if hasattr( + batch_logs[key], 'shape') else len(batch_logs[key]) + N = min(N, max_images) + batch_logs[key] = batch_logs[key][:N] + # In batch_logs: images & instruction + if isinstance(batch_logs[key], torch.Tensor): + batch_logs[key] = batch_logs[key].detach().cpu() + if clamp: + try: + batch_logs[key] = 
torch.clamp(batch_logs[key].float(), -1., + 1.) + except RuntimeError: + print("clamp_scalar_cpu not implemented for Half") + return batch_logs + + +# ---------------------------------------------------------------------------------------------- + + +def fill_with_black_squares(video, desired_len: int) -> Tensor: + if len(video) >= desired_len: + return video + + return torch.cat([ + video, + torch.zeros_like(video[0]).unsqueeze(0).repeat( + desired_len - len(video), 1, 1, 1), + ], + dim=0) + + +# ---------------------------------------------------------------------------------------------- +def load_num_videos(data_path, num_videos): + # First argument can be either data_path of np array + if isinstance(data_path, str): + videos = np.load(data_path)['arr_0'] # NTHWC + elif isinstance(data_path, np.ndarray): + videos = data_path + else: + raise Exception + + if num_videos is not None: + videos = videos[:num_videos, :, :, :, :] + return videos + + +def npz_to_video_grid(data_path, + out_path, + num_frames, + fps, + num_videos=None, + nrow=None, + verbose=True): + if isinstance(data_path, str): + videos = load_num_videos(data_path, num_videos) + elif isinstance(data_path, np.ndarray): + videos = data_path + else: + raise Exception + n, t, h, w, c = videos.shape + videos_th = [] + for i in range(n): + video = videos[i, :, :, :, :] + images = [video[j, :, :, :] for j in range(t)] + images = [to_tensor(img) for img in images] + video = torch.stack(images) + videos_th.append(video) + if verbose: + videos = [ + fill_with_black_squares(v, num_frames) + for v in tqdm(videos_th, desc='Adding empty frames') + ] + else: + videos = [fill_with_black_squares(v, num_frames) + for v in videos_th] # NTCHW + + frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) + if nrow is None: + nrow = int(np.ceil(np.sqrt(n))) + if verbose: + frame_grids = [ + make_grid(fs, nrow=nrow) + for fs in tqdm(frame_grids, desc='Making grids') + ] + else: + frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] + + if os.path.dirname(out_path) != "": + os.makedirs(os.path.dirname(out_path), exist_ok=True) + frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute( + 0, 2, 3, 1) + torchvision.io.write_video(out_path, + frame_grids, + fps=fps, + video_codec='h264', + options={'crf': '10'}) diff --git a/src/unifolm_wma/utils/train.py b/src/unifolm_wma/utils/train.py new file mode 100644 index 0000000..225fbf2 --- /dev/null +++ b/src/unifolm_wma/utils/train.py @@ -0,0 +1,231 @@ +import os +import logging + +mainlogger = logging.getLogger('mainlogger') + +import torch +import pandas as pd + +from omegaconf import OmegaConf +from collections import OrderedDict + + +def init_workspace(name, logdir, model_config, lightning_config, rank=0): + workdir = os.path.join(logdir, name) + ckptdir = os.path.join(workdir, "checkpoints") + cfgdir = os.path.join(workdir, "configs") + loginfo = os.path.join(workdir, "loginfo") + + # Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower) + os.makedirs(workdir, exist_ok=True) + os.makedirs(ckptdir, exist_ok=True) + os.makedirs(cfgdir, exist_ok=True) + os.makedirs(loginfo, exist_ok=True) + + if rank == 0: + if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: + os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'), + exist_ok=True) + OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml")) + OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), + 
os.path.join(cfgdir, "lightning.yaml")) + return workdir, ckptdir, cfgdir, loginfo + + +def check_config_attribute(config, name): + if name in config: + value = getattr(config, name) + return value + else: + return None + + +def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger): + default_callbacks_cfg = { + "model_checkpoint": { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch}", + "verbose": True, + "save_last": False, + } + }, + "batch_logger": { + "target": "unifolm_wma.utils.callbacks.ImageLogger", + "params": { + "save_dir": logdir, + "batch_frequency": 1000, + "max_images": 4, + "clamp": True, + } + }, + "learning_rate_logger": { + "target": "pytorch_lightning.callbacks.LearningRateMonitor", + "params": { + "logging_interval": "step", + "log_momentum": False + } + }, + "cuda_callback": { + "target": "unifolm_wma.utils.callbacks.CUDACallback", + }, + } + + # Optional setting for saving checkpoints + monitor_metric = check_config_attribute(config.model.params, "monitor") + if monitor_metric is not None: + mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.") + default_callbacks_cfg["model_checkpoint"]["params"][ + "monitor"] = monitor_metric + default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3 + default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min" + + if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: + mainlogger.info( + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.' + ) + default_metrics_over_trainsteps_ckpt_dict = { + 'metrics_over_trainsteps_checkpoint': { + "target": 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch}-{step}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } + } + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + + if "callbacks" in lightning_config: + callbacks_cfg = lightning_config.callbacks + else: + callbacks_cfg = OmegaConf.create() + callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + + return callbacks_cfg + + +def get_trainer_logger(lightning_config, logdir, on_debug): + default_logger_cfgs = { + "tensorboard": { + "target": "pytorch_lightning.loggers.TensorBoardLogger", + "params": { + "save_dir": logdir, + "name": "tensorboard", + } + }, + "testtube": { + "target": "pytorch_lightning.loggers.CSVLogger", + "params": { + "name": "testtube", + "save_dir": logdir, + } + }, + } + os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True) + default_logger_cfg = default_logger_cfgs["tensorboard"] + if "logger" in lightning_config: + logger_cfg = lightning_config.logger + else: + logger_cfg = OmegaConf.create() + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + return logger_cfg + + +def get_trainer_strategy(lightning_config): + default_strategy_dict = { + "target": "pytorch_lightning.strategies.DDPShardedStrategy" + } + if "strategy" in lightning_config: + strategy_cfg = lightning_config.strategy + return strategy_cfg + else: + strategy_cfg = OmegaConf.create() + + strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg) + return strategy_cfg + + +def load_checkpoints(model, model_cfg): + if check_config_attribute(model_cfg, "pretrained_checkpoint"): + pretrained_ckpt = model_cfg.pretrained_checkpoint + assert 
os.path.exists(
+            pretrained_ckpt
+        ), "Error: Pre-trained checkpoint NOT found at: %s" % pretrained_ckpt
+        mainlogger.info(">>> Load weights from pretrained checkpoint")
+
+        pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
+        try:
+            if 'state_dict' in pl_sd.keys():
+                model.load_state_dict(pl_sd["state_dict"], strict=False)
+                mainlogger.info(
+                    ">>> Loaded weights from pretrained checkpoint: %s" %
+                    pretrained_ckpt)
+            else:
+                # deepspeed checkpoint: strip the wrapper prefix from each key
+                new_pl_sd = OrderedDict()
+                for key in pl_sd['module'].keys():
+                    new_pl_sd[key[16:]] = pl_sd['module'][key]
+                model.load_state_dict(new_pl_sd, strict=False)
+        except Exception:
+            model.load_state_dict(pl_sd)
+    else:
+        mainlogger.info(">>> Start training from scratch")
+
+    return model
+
+
+def set_logger(logfile, name='mainlogger'):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logfile, mode='w')
+    fh.setLevel(logging.INFO)
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    fh.setFormatter(
+        logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
+    ch.setFormatter(logging.Formatter("%(message)s"))
+    logger.addHandler(fh)
+    logger.addHandler(ch)
+    return logger
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters())
+
+
+def count_trainable_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_num_parameters(model):
+    models = [('World Model', model.model.diffusion_model),
+              ('Action Head', model.model.diffusion_model.action_unet),
+              ('State Head', model.model.diffusion_model.state_unet),
+              ('Total Trainable', model),
+              ('Total', model)]
+
+    data = []
+    for name, module in models:
+        if name == "Total Trainable":
+            total_params = count_trainable_parameters(module)
+        else:
+            total_params = count_parameters(module)
+        if total_params < 0.1e9:
+            total_params_value = round(total_params / 1e6, 2)
+            unit = 'M'
+        else:
+            total_params_value = round(total_params / 1e9, 2)
+            unit = 'B'
+
+        data.append({
+            'Model Name': name,
+            'Params': f"{total_params_value} {unit}"
+        })
+
+    df = pd.DataFrame(data)
+    print(df)
diff --git a/src/unifolm_wma/utils/utils.py b/src/unifolm_wma/utils/utils.py
new file mode 100644
index 0000000..177d1bf
--- /dev/null
+++ b/src/unifolm_wma/utils/utils.py
@@ -0,0 +1,81 @@
+import importlib
+import os
+
+import numpy as np
+import cv2
+import torch
+import torch.distributed as dist
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(
+            f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params."
+ ) + return total_params + + +def check_istarget(name, para_list): + """ + name: full name of source para + para_list: partial name of target para + """ + istarget = False + for para in para_list: + if para in name: + return True + return istarget + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def load_npz_from_dir(data_dir): + data = [ + np.load(os.path.join(data_dir, data_name))['arr_0'] + for data_name in os.listdir(data_dir) + ] + data = np.concatenate(data, axis=0) + return data + + +def load_npz_from_paths(data_paths): + data = [np.load(data_path)['arr_0'] for data_path in data_paths] + data = np.concatenate(data, axis=0) + return data + + +def resize_numpy_image(image, + max_resolution=512 * 512, + resize_short_edge=None): + h, w = image.shape[:2] + if resize_short_edge is not None: + k = resize_short_edge / min(h, w) + else: + k = max_resolution / (h * w) + k = k**0.5 + h = int(np.round(h * k / 64)) * 64 + w = int(np.round(w * k / 64)) * 64 + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def setup_dist(args): + if dist.is_initialized(): + return + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group('nccl', init_method='env://') diff --git a/unitree_deploy/README.md b/unitree_deploy/README.md new file mode 100644 index 0000000..c0d1ba6 --- /dev/null +++ b/unitree_deploy/README.md @@ -0,0 +1,223 @@ +# Unitree Deploy + +
+🌎 English | [🇨🇳 中文](./docs/README_cn.md)
+
+
+This document provides instructions for setting up the deployment environment for the Unitree G1 (with gripper) and Z1 platforms, including dependency installation, image service startup, and gripper control.
+
+# 0. 📖 Introduction
+
+This repository is used for model deployment with Unitree robots.
+
+---
+
+# 1. 🛠️ Environment Setup
+
+```bash
+conda create -n unitree_deploy python=3.10 && conda activate unitree_deploy
+
+conda install pinocchio -c conda-forge
+pip install -e .
+
+# Optional: Install lerobot dependencies
+pip install -e ".[lerobot]"
+
+git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
+cd unitree_sdk2_python && pip install -e . && cd ..
+```
+
+---
+# 2. 🚀 Start
+
+**Tip: Keep all devices on the same LAN**
+
+## 2.1 🤖 Run G1 with Dex_1 Gripper
+
+### 2.1.1 📷 Image Capture Service Setup (G1 Board)
+
+[To start the image_server, follow these steps](https://github.com/unitreerobotics/xr_teleoperate?tab=readme-ov-file#31-%EF%B8%8F-image-service)
+1. Connect to the G1 board:
+   ```bash
+   ssh unitree@192.168.123.164   # Password: 123
+   ```
+
+2. Activate the environment and start the image server:
+   ```bash
+   conda activate tv
+   cd ~/image_server
+   python image_server.py
+   ```
+
+---
+
+### 2.1.2 🤏 Dex_1 Gripper Service Setup (Development PC2)
+
+Refer to the [Dex_1 Gripper Installation Guide](https://github.com/unitreerobotics/dex1_1_service?tab=readme-ov-file#1--installation) for detailed setup instructions.
+
+1. Navigate to the service directory:
+   ```bash
+   cd ~/dex1_1_service/build
+   ```
+
+2. Start the gripper service (**use `ifconfig` to check your own DDS network interface**):
+   ```bash
+   sudo ./dex1_1_gripper_server --network eth0 -l -r
+   ```
+
+3. Verify communication with the gripper service:
+   ```bash
+   ./test_dex1_1_gripper_server --network eth0 -l -r
+   ```
+
+---
+
+### 2.1.3 ✅ Testing
+
+Perform the following tests to ensure proper functionality:
+
+- **Dex1 Gripper Test**:
+  ```bash
+  python test/endeffector/test_dex1.py
+  ```
+
+- **G1 Arm Test**:
+  ```bash
+  python test/arm/g1/test_g1_arm.py
+  ```
+
+- **Image Client Camera Test**:
+  ```bash
+  python test/camera/test_image_client_camera.py
+  ```
+
+- **G1 Datasets Replay**:
+  ```bash
+  # --repo-id      Your unique repo ID on the Hugging Face Hub
+  # --robot_type   The type of the robot, e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1
+
+  python test/test_replay.py --repo-id unitreerobotics/G1_CameraPackaging_NewDataset --robot_type g1_dex1
+  ```
+---
+
+## 2.2 🦿 Run Z1
+
+### 2.2.1 🦿 Z1 Setup
+Clone and build the required repositories:
+
+1. Download [z1_controller](https://github.com/unitreerobotics/z1_controller.git) and [z1_sdk](https://github.com/unitreerobotics/z1_sdk.git).
+
+2. Build the repositories:
+   ```bash
+   mkdir build && cd build
+   cmake .. && make -j
+   ```
+
+3. Copy the `unitree_arm_interface` library [Modify according to your own path]:
+   ```bash
+   cp z1_sdk/lib/unitree_arm_interface.cpython-310-x86_64-linux-gnu.so ./unitree_deploy/robot_devices/arm
+   ```
+
+4. 
Start the Z1 controller [Modify according to your own path]:
+   ```bash
+   cd z1_controller/build && ./z1_ctrl
+   ```
+
+---
+
+### 2.2.2 ✅ Testing
+
+Run the following tests:
+
+- **Realsense Camera Test**:
+  ```bash
+  python test/camera/test_realsense_camera.py   # Modify the serial number to match your RealSense
+  ```
+
+- **Z1 Arm Test**:
+  ```bash
+  python test/arm/z1/test_z1_arm.py
+  ```
+
+- **Z1 Environment Test**:
+  ```bash
+  python test/arm/z1/test_z1_env.py
+  ```
+
+- **Z1 Datasets Replay**:
+  ```bash
+  # --repo-id      Your unique repo ID on the Hugging Face Hub
+  # --robot_type   The type of the robot, e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1
+
+  python test/test_replay.py --repo-id unitreerobotics/Z1_StackBox_Dataset --robot_type z1_realsense
+  ```
+---
+
+## 2.3 🦿 Run Z1_Dual
+
+### 2.3.1 🦿 Z1 and Dex1 Setup
+Clone and build the required repositories:
+
+1. Download and compile the code following the Z1 steps above, and download the gripper program so it can be started locally.
+
+2. [Set up multi-machine control according to the document](https://support.unitree.com/home/zh/Z1_developer/sdk_operation).
+
+3. [Download the modified z1_sdk_1 and compile it](https://github.com/unitreerobotics/z1_sdk/tree/z1_dual), then copy the `unitree_arm_interface` library [Modify according to your own path]:
+   ```bash
+   cp z1_sdk/lib/unitree_arm_interface.cpython-310-x86_64-linux-gnu.so ./unitree_deploy/robot_devices/arm
+   ```
+
+4. Start the Z1 controllers [Modify according to your own path]:
+   ```bash
+   cd z1_controller/build && ./z1_ctrl
+   cd z1_controller_1/build && ./z1_ctrl
+   ```
+5. Start the gripper service (**use `ifconfig` to check your own DDS network interface**):
+   ```
+   sudo ./dex1_1_gripper_server --network eth0 -l -r
+   ```
+---
+
+### 2.3.2 ✅ Testing
+
+Run the following tests:
+
+- **Z1_Dual Arm Test**:
+  ```bash
+  python test/arm/z1/test_z1_arm_dual.py
+  ```
+
+- **Z1_Dual Datasets Replay**:
+  ```bash
+  # --repo-id      Your unique repo ID on the Hugging Face Hub
+  # --robot_type   The type of the robot, e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1
+
+  python test/test_replay.py --repo-id unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2 --robot_type z1_dual_dex1_realsense
+  ```
+---
+
+
+# 3. 🧠 Inference and Deploy
+1. [Modify the corresponding parameters according to your configuration](./unitree_deploy/robot/robot_configs.py)
+2. Go back to **step 2 of the Client Setup** under [Inference and Deployment under Decision-Making Mode](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/README.md).
+
+# 4. 🏗️ Code Structure
+
+[If you want to add your own robot equipment, you can build it according to this document](./docs/GettingStarted.md)
+
+
+# 5. 🤔 Troubleshooting
+
+For assistance, contact the project maintainer or refer to the respective GitHub repository documentation. 📖
+
+
+# 6. 🙏 Acknowledgement
+
+This code builds upon the following open-source codebases. Please visit the URLs to see the respective LICENSES (if you find these projects valuable, please consider giving them a star):
+
+1. https://github.com/huggingface/lerobot
+2. 
https://github.com/unitreerobotics/unitree_sdk2_python diff --git a/unitree_deploy/docs/GettingStarted.md b/unitree_deploy/docs/GettingStarted.md new file mode 100644 index 0000000..dadb7b2 --- /dev/null +++ b/unitree_deploy/docs/GettingStarted.md @@ -0,0 +1,70 @@ +# Getting Started + +### Code framework + + +| Module Name | Documentation Link | +| ------------------------- | -------------------------------------------------- | +| robots | [build_robot](./build_robot.md) | +| robot_devices/arm | [add_robot_arm](./add_robot_arm.md) | +| robot_devices/cameras | [add_robot_camera](./add_robot_camera.md) | +| robot_devices/endeffector | [add_robot_endeffector](./add_robot_endeffector.md)| + +### Simple Usage (Example code, not executable) + +```python +import time +import math +import torch + +from unitree_deploy.robot.robot_utils import make_robot +from unitree_deploy.robot_devices.robots_devices_utils import precise_wait + +class YourPolicy: + def predict_action(self, observation, policy): + # Logic for predicting action + pass + +class UnitreeEnv: + def __init__(self): + self.robot = make_robot(self.robot_type) + if not self.robot.is_connected: + self.robot.connect() + # If disconnection is needed, call disconnect() here + # self.robot.disconnect() + + def get_obs(self): + # Get observation + observation = self.robot.capture_observation() + return observation + + def step(self, pred_action, t_command_target): + # Execute action + t_cycle_end = time.monotonic() + self.control_dt + t_command_target = t_cycle_end + self.control_dt + action = self.robot.send_action(torch.from_numpy(pred_action), t_command_target) + precise_wait(t_cycle_end) + return action + +if __name__ == "__main__": + policy = YourPolicy() # Create policy instance + env = UnitreeEnv() # Create environment instance + + t_start = time.monotonic() # Get start time + iter_idx = 0 # Initialize iteration index + control_dt = 1 / 30 # Control loop interval (30Hz) + + try: + while True: + t_cycle_end = t_start + (iter_idx + 1) * control_dt # Calculate end time of current cycle + t_command_target = t_cycle_end + control_dt # Calculate command target time + + observation = env.get_obs() # Get environment observation + pred_action = policy.predict_action(observation, policy) # Predict action + env.step(pred_action, t_command_target) # Execute action + + precise_wait(t_cycle_end) # Wait until cycle end + iter_idx += 1 # Update iteration index + finally: + # Perform cleanup operations on exit (e.g., disconnect robot) + pass diff --git a/unitree_deploy/docs/README_cn.md b/unitree_deploy/docs/README_cn.md new file mode 100644 index 0000000..b826f69 --- /dev/null +++ b/unitree_deploy/docs/README_cn.md @@ -0,0 +1,213 @@ +# Unitree Deploy + +本文档提供了为 Unitree G1 和 Z1 平台设置部署环境的说明,包括依赖安装、图像服务启动和夹爪控制。 + +# 0. 📖 简介 + +此代码库用于 Unitree 机器人模型的部署。 + +--- + +# 1. 🛠️ 环境设置 + +```bash +conda create -n unitree_deploy python=3.10 && conda activate unitree_deploy + +conda install pinocchio -c conda-forge +pip install -e . + +# 可选:安装 lerobot 依赖 +pip install -e ".[lerobot]" + +git clone https://github.com/unitreerobotics/unitree_sdk2_python.git +cd unitree_sdk2_python && pip install -e . && cd .. +``` + +--- +# 2. 🚀 启动 + +**提示:确保所有设备处于同一局域网内** + +## 2.1 🤖 运行 G1 和 Dex_1 夹爪 + +### 2.1.1 📷 图像捕获服务设置(G1 pc2) + +[按照以下步骤启动 image_server](https://github.com/unitreerobotics/xr_teleoperate?tab=readme-ov-file#31-%EF%B8%8F-image-service) +1. 连接到 G1: + ```bash + ssh unitree@192.168.123.164 # 密码:123 + ``` + +2. 
激活环境并启动图像服务: + ```bash + conda activate tv + cd ~/image_server + python image_server.py + ``` + +--- + +### 2.1.2 🤏 Dex_1 夹爪服务设置(开发 PC2) + +参考 [Dex_1 夹爪安装指南](https://github.com/unitreerobotics/dex1_1_service?tab=readme-ov-file#1--installation) 获取详细设置说明。 + +1. 进入服务目录: + ```bash + cd ~/dex1_1_service/build + ``` + +2. 启动夹爪服务,**ifconfig 检查其自身的 dds 网络接口**: + ```bash + sudo ./dex1_1_gripper_server --network eth0 -l -r + ``` + +3. 验证与夹爪服务的通信: + ```bash + ./test_dex1_1_gripper_server --network eth0 -l -r + ``` + +--- + +### 2.1.3 ✅ 测试 + +执行以下测试以确保功能正常: + +- **Dex1 夹爪测试**: + ```bash + python test/endeffector/test_dex1.py + ``` + +- **G1 机械臂测试**: + ```bash + python test/arm/g1/test_g1_arm.py + ``` + +- **图像客户端相机测试**: + ```bash + python test/camera/test_image_client_camera.py + ``` + +- **G1 数据集回放**: + ```bash + # --repo-id Your unique repo ID on Hugging Face Hub + # --robot_type The type of the robot e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1, + + python test/test_replay.py --repo-id unitreerobotics/G1_CameraPackaging_NewDataset --robot_type g1_dex1 + ``` +--- + +## 2.2 🦿 运行 Z1 + +### 2.2.1 🦿 Z1 设置 +克隆并构建所需的代码库: + +1. 下载 [z1_controller](https://github.com/unitreerobotics/z1_controller.git) 和 [z1_sdk](https://github.com/unitreerobotics/z1_sdk.git)。 + +2. 构建代码库: + ```bash + mkdir build && cd build + cmake .. && make -j + ``` + +3. 复制 `unitree_arm_interface` 库:[根据您的路径修改] + ```bash + cp z1_sdk/lib/unitree_arm_interface.cpython-310-x86_64-linux-gnu.so ./unitree_deploy/robot_devices/arm + ``` + +4. 启动 Z1 控制器 [根据您的路径修改]: + ```bash + cd z1_controller/build && ./z1_ctrl + ``` + +--- + +### 2.2.2 ✅ 测试 + +运行以下测试: + +- **Realsense 相机测试**: + ```bash + python test/camera/test_realsense_camera.py # 根据您的 Realsense 修改对应的序列号 + ``` + +- **Z1 机械臂测试**: + ```bash + python test/arm/z1/test_z1_arm.py + ``` + +- **Z1 环境测试**: + ```bash + python test/arm/z1/test_z1_env.py + ``` + +- **Z1 数据集回放**: + ```bash + # --repo-id Your unique repo ID on Hugging Face Hub + # --robot_type The type of the robot e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1, + + python test/test_replay.py --repo-id unitreerobotics/Z1_StackBox_Dataset --robot_type z1_realsense + ``` +--- + +## 2.3 🦿 运行 Z1_Dual + +### 2.3.1 🦿 Z1 设置和 Dex1 设置 +克隆并构建所需的代码库: + +1. 按照上述 Z1 步骤下载并编译代码,并下载夹爪程序以本地启动。 + +2. [根据文档修改多机控制](https://support.unitree.com/home/zh/Z1_developer/sdk_operation) + +3. [下载修改后的 z1_sdk_1 并编译](https://github.com/unitreerobotics/z1_sdk/tree/z1_dual),复制 `unitree_arm_interface` 库:[根据您的路径修改] + ```bash + cp z1_sdk/lib/unitree_arm_interface.cpython-310-x86_64-linux-gnu.so ./unitree_deploy/robot_devices/arm + ``` + +4. 启动 Z1 控制器 [根据您的路径修改]: + ```bash + cd z1_controller/builb && ./z1_ctrl + cd z1_controller_1/builb && ./z1_ctrl + ``` +5. 启动夹爪服务,**ifconfig 检查其自身的 dds 网络接口**: + ``` + sudo ./dex1_1_gripper_server --network eth0 -l -r + ``` +--- + +### 2.3.2 ✅ 测试 + +运行以下测试: + +- **Z1_Dual 机械臂测试**: + ```bash + python test/arm/z1/test_z1_arm_dual.py + ``` + +- **Z1_Dual 数据集回放**: + ```bash + # --repo-id Your unique repo ID on Hugging Face Hub + # --robot_type The type of the robot e.g., z1_dual_dex1_realsense, z1_realsense, g1_dex1, + + python test/test_replay.py --repo-id unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2 --robot_type z1_dual_dex1_realsense + ``` +--- + + +# 3.🧠 推理与部署 +1. [根据您的配置修改相应参数](./unitree_deploy/robot/robot_configs.py) +2. 返回 [决策模式下的推理与部署](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/README.md) 中的 **客户端设置步骤 2**。 + +# 4.🏗️ 代码结构 + +[如果您想添加自己的机器人设备,可以根据此文档进行构建](./docs/GettingStarted.md) + +# 5. 
🤔 故障排除 + +如需帮助,请联系项目维护人员或参考相应的 GitHub 仓库文档。📖 + +# 6. 🙏 致谢 + +此代码基于以下开源代码库构建。请访问相关 URL 查看相应的 LICENSES(如果您觉得这些项目有价值,请为它们点亮星星): + +1. https://github.com/huggingface/lerobot +2. https://github.com/unitreerobotics/unitree_sdk2_python diff --git a/unitree_deploy/docs/add_robot_arm.md b/unitree_deploy/docs/add_robot_arm.md new file mode 100644 index 0000000..900ab64 --- /dev/null +++ b/unitree_deploy/docs/add_robot_arm.md @@ -0,0 +1,76 @@ +# How to Build Your Own Arm + +### Define your own config for the robot arm (unitree_deploy/robot_devices/arm/config.py) + +```python +@ArmConfig.register_subclass("z1") # Register your custom arm wrapper. Here use def __init__(self, config: Z1DualArmConfig): +@dataclass +class Z1ArmConfig(ArmConfig): + port: str + motors: dict[str, tuple[int, str]] + mock: bool = False + init_pose_left: list = None + init_pose_right: list = None + control_dt: float = 1/500.0 + +# Default parameters go first [parameters that may need to be customized], +# Non-default parameters go later [fixed parameters] +``` + +### Description of methods in your arm class (unitree_deploy/robot_devices/arm/utils.py) + +```python +# Base class for Arm, extensible with required methods + +class Arm(Protocol): + def connect(self): ... + def disconnect(self): ... + def motor_names(self): ... + + def read_current_motor_q(self): ... + def read_current_arm_q(self): ... + def read_current_arm_dq(self): ... + def write_arm(self): ... + + def arm_ik(self): ... +``` + +How to implement external calls? +Use make_arm_motors_buses_from_configs [based on the config file] to construct the UnitreeRobot class. +Use make_arm_motors_bus [based on arm_type] which is generally used for external module loading. + +### Implementation of the arm class (unitree_deploy/robot_devices/arm/.../....py) + +```python + # These methods need to be implemented and completed + def connect(self): ... + def disconnect(self): ... + def motor_names(self): ... + # connect() and disconnect() should handle initialization and homing respectively + + def read_current_motor_q(self): ... + def read_current_arm_q(self): ... + def read_current_arm_dq(self): ... + # Outputs should be unified as np.ndarray + + def write_arm(self): ... + # Write control commands here + + def arm_ik(self): ... + # Wrap IK into your own arm class for external calling + + # Private/protected properties [for reading motor names, IDs, etc.] + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] +``` + +All arms use threading to implement \_subscribe_motor_state and \_ctrl_motor_state threads for internal reading and writing within the class. diff --git a/unitree_deploy/docs/add_robot_camera.md b/unitree_deploy/docs/add_robot_camera.md new file mode 100644 index 0000000..d006c73 --- /dev/null +++ b/unitree_deploy/docs/add_robot_camera.md @@ -0,0 +1,66 @@ +# How to build your own cameras + +### Define your own config for cameras (unitree_deploy/robot_devices/cameras/config.py) + +```python +@CameraConfig.register_subclass("opencv") # Define and wrap your own cameras. 
Here use def __init__(self, config: OpenCVCameraConfig): +@dataclass +class OpenCVCameraConfig(CameraConfig): + """ + Example of tested options for Intel Real Sense D405: + + OpenCVCameraConfig(0, 30, 640, 480) + OpenCVCameraConfig(0, 60, 640, 480) + OpenCVCameraConfig(0, 90, 640, 480) + OpenCVCameraConfig(0, 30, 1280, 720) + + """ + # Define the required camera parameters + camera_index: int + fps: int | None = None + width: int | None = None + height: int | None = None + color_mode: str = "rgb" + channels: int | None = None + rotation: int | None = None + mock: bool = False + + def __post_init__(self): + if self.color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + ) + + self.channels = 3 + + if self.rotation not in [-90, None, 90, 180]: + raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + +# Default parameters go first [parameters that need to be customized], +# Non-default parameters go later [fixed parameters] +``` + +### Description of methods in your cameras class (unitree_deploy/robot_devices/cameras/utils.py) + +```python +# Base class for cameras, extensible with required methods + +class Camera(Protocol): + def connect(self): ... + def read(self, temporary_color: str | None = None) -> np.ndarray: ... # Single-threaded reading + def async_read(self) -> np.ndarray: ... # Multi-threaded + def disconnect(self): ... +``` + +How can external modules implement calls? Use **make_cameras_from_configs [based on configuration files]** to construct the `UnitreeRobot` class. +**make_camera [based on camera_type]** is generally used for external module loading. + +### Implementation of the `camera` class (unitree_deploy/robot_devices/camera/.../....py) + +```python + # These need to be completed, focusing on implementing these two parts + def read(self, temporary_color: str | None = None) -> np.ndarray: ... # Single-threaded reading + def async_read(self) -> np.ndarray: ... # Multi-threaded +``` + +All cameras use threading to implement `async_read` for internal read and write operations. diff --git a/unitree_deploy/docs/add_robot_endeffector.md b/unitree_deploy/docs/add_robot_endeffector.md new file mode 100644 index 0000000..9680d26 --- /dev/null +++ b/unitree_deploy/docs/add_robot_endeffector.md @@ -0,0 +1,77 @@ +# How to Build Your Own End-Effector [Currently dex_1 and dex_3 are available] + +### Define your own config for the end-effector (unitree_deploy/robot_devices/endeffector/config.py) + +```python +@EndEffectorConfig.register_subclass("gripper") # Register your custom end-effector wrapper. Here it uses def __init__(self, config: GripperConfig): +@dataclass +class GripperConfig(EndEffectorConfig): + motors: dict[str, tuple[int, str]] + unit_test: bool = False + control_dt: float = 1/200 + mock: bool = False + + def __post_init__(self): + if self.control_dt < 0.002: + raise ValueError(f"`control_dt` must > 1/500 (got {self.control_dt})") + +# Default arguments should be placed first [parameters that may need to be customized], +# Non-default arguments should be placed later [fixed or less important parameters]. +``` + +### Description of methods in your end-effector class (unitree_deploy/robot_devices/endeffector/utils.py) + +```python +# Base class for EndEffector, extend with required methods + +class EndEffector(Protocol): + def connect(self): ... + def disconnect(self): ... + def motor_names(self): ... + + def read_current_endeffector_q(self): ... 
+    def read_current_endeffector_dq(self): ...
+    def write_endeffector(self): ...
+
+    def endeffector_ik(self): ...
+```
+
+How to call externally?
+Use make_endeffector_motors_buses_from_configs → construct the UnitreeRobot class based on the config file.
+Use make_endeffector_motors_bus → construct based on endeffector_type (typically for external module loading).
+
+### Implementation of your end-effector class (unitree_deploy/robot_devices/endeffector/.../....py)
+
+```python
+    # These methods need to be implemented and completed
+    def connect(self): ...
+    def disconnect(self): ...
+    def motor_names(self): ...
+    # connect() and disconnect() should handle initialization and homing respectively
+
+    def read_current_endeffector_q(self): ...
+    def read_current_endeffector_dq(self): ...
+    # Outputs should be unified as np.ndarray
+
+    def write_endeffector(self): ...
+    # Write control commands here
+
+    def endeffector_ik(self): ...
+    # Wrap IK into your own end-effector class, to be called externally
+
+    # Private/protected properties
+    # (for reading motor names, IDs, etc. These will be used in the UnitreeRobot class for dataset encapsulation)
+    @property
+    def motor_names(self) -> list[str]:
+        return list(self.motors.keys())
+
+    @property
+    def motor_models(self) -> list[str]:
+        return [model for _, model in self.motors.values()]
+
+    @property
+    def motor_indices(self) -> list[int]:
+        return [idx for idx, _ in self.motors.values()]
+```
+
+As with the arms, use threading to implement \_subscribe_gripper_motor_state (a thread for reading motor states) and \_ctrl_gripper_motor (a thread for motor control); both threads should run internally within the class.
diff --git a/unitree_deploy/docs/build_robot.md b/unitree_deploy/docs/build_robot.md
new file mode 100644
index 0000000..b90b9a7
--- /dev/null
+++ b/unitree_deploy/docs/build_robot.md
@@ -0,0 +1,140 @@
+# Build your own robot
+
+### Add your own config (unitree_deploy/robot/robot_configs.py)
+
+The base class of robot config is defined as **UnitreeRobotConfig**
+
+```python
+class UnitreeRobotConfig(RobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})  # Corresponding to your own camera
+    arm: dict[str, ArmConfig] = field(default_factory=lambda: {})  # Corresponding to your own arm
+    endeffector: dict[str, EndEffectorConfig] = field(default_factory=lambda: {})  # Corresponding to your own end-effector
+
+    mock: bool = False  # Simulation [To be implemented, for debugging, to check some class definitions and message type formats]
+```
+
+Specific example: fill in a [name] entry for each of the robot_devices → cameras, arm, endeffector.
+If not provided, they default to empty and will not affect the system.
+(In principle, different robot_devices in different quantities can be combined.)
+
+```python
+class Z1dual_Dex1_Opencv_RobotConfig(UnitreeRobotConfig):
+
+    # Troubleshooting: If one of your IntelRealSense cameras freezes during
+    # data recording due to bandwidth limit, you might need to plug the camera
+    # into another USB hub or PCIe card.
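+    # The dict keys below (e.g. "cam_high", "z1_dual", "gripper") are the
+    # [name] identifiers for each device, as described above.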
+    cameras: dict[str, CameraConfig] = field(
+        default_factory=lambda: {  # Add corresponding configs for different cameras: [name]: OpenCVCameraConfig + required parameters
+            "cam_high": OpenCVCameraConfig(
+                camera_index="/dev/video0",
+                fps=30,
+                width=640,
+                height=480,
+            ),
+            "cam_left_wrist": OpenCVCameraConfig(
+                camera_index="/dev/video2",
+                fps=30,
+                width=640,
+                height=480,
+            ),
+            "cam_right_wrist": OpenCVCameraConfig(
+                camera_index="/dev/video4",
+                fps=30,
+                width=640,
+                height=480,
+            ),
+        }
+    )
+
+    arm: dict[str, ArmConfig] = field(
+        default_factory=lambda: {
+            "z1_dual": Z1DualArmConfig(  # Add corresponding configs for different arms: [name]: Z1DualArmConfig + required parameters
+                unit_test = False,
+                init_pose_left = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                init_pose_right = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                control_dt = 1/500.0,
+                motors={
+                    # name: (index, model)
+                    "kLeftWaist": [0, "z1-joint"],
+                    "kLeftShoulder": [1, "z1-joint"],
+                    "kLeftElbow": [2, "z1-joint"],
+                    "kLeftForearmRoll": [3, "z1-joint"],
+                    "kLeftWristAngle": [4, "z1-joint"],
+                    "kLeftWristRotate": [5, "z1-joint"],
+
+                    "kRightWaist": [7, "z1-joint"],
+                    "kRightShoulder": [8, "z1-joint"],
+                    "kRightElbow": [9, "z1-joint"],
+                    "kRightForearmRoll": [10, "z1-joint"],
+                    "kRightWristAngle": [11, "z1-joint"],
+                    "kRightWristRotate": [12, "z1-joint"],
+                },
+            ),
+        }
+    )
+
+    endeffector: dict[str, EndEffectorConfig] = field(
+        default_factory=lambda: {
+            "gripper": GripperConfig(  # Add corresponding configs for different end-effectors: [name]: GripperConfig + required parameters
+                unit_test = False,
+                control_dt = 1/250,
+                motors={
+                    # name: (index, model)
+                    "kLeftGripper": [0, "z1_gripper-joint"],
+                    "kRightGripper": [1, "z1_gripper-joint"],
+                },
+            ),
+        }
+    )
+
+    mock: bool = False
+```
+
+---
+
+### robot utils (unitree_deploy/robot/utils.py)
+
+```python
+# Implementation of the Robot base class
+
+class Robot(Protocol):
+    robot_type: str
+    features: dict
+
+    def connect(self): ...  # Connect devices (the cameras, arms, and end-effectors of robot_devices)
+    def capture_observation(self): ...  # Get the current state, including data from camera + arm + end-effector
+    def send_action(self, action): ...  # Send an action to the arm + end-effector actuators; usable for model inference and data replay
+    def disconnect(self): ...  # Disconnect devices
+```
+
+The external entry points **make_robot_from_config** and **make_robot** are used in
+**control_robot** to initialize the robot and implement the specific
+functions.
+
+---
+
+### manipulator (unitree_deploy/robot/manipulator.py)
+
+UnitreeRobot is initialized from a **UnitreeRobotConfig**.
+
+```python
+    # Several important parts of the implementation
+
+    def capture_observation(self):  # Get current obs; returns {observation.state, observation.images}
+
+    def send_action(  # Model inference and data replay; receives action + time
+        self, action: torch.Tensor, t_command_target: float | None = None
+    ) -> torch.Tensor:
+
+    # Here we input device data.
+    # Output (arm + end-effector) joint angle positions, end-effector positions, or other data conversions (IK is implemented here!).
+    # The output is uniformly converted into joint angle positions {"left": arm_joint_points, "right": arm_joint_points} + {"left": endeffector_joint_points, "right": endeffector_joint_points}.
+    # Why distinguish left and right? Because this separates single-arm cases, different arms, and different end-effectors,
+    # so the implementation can work properly.
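+
+    # Illustrative call (a sketch; the tensor below is a placeholder):
+    #   action = robot.send_action(torch.from_numpy(pred_action), t_command_target)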
+ def convert_data_based_on_robot_type(self, robot_type: str, leader_pos: dict[str, np.ndarray] + ) -> None | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: +``` diff --git a/unitree_deploy/pyproject.toml b/unitree_deploy/pyproject.toml new file mode 100644 index 0000000..e0aabb3 --- /dev/null +++ b/unitree_deploy/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "unitree_deploy" +version = "0.0.3" +description = "unitree deploy" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "BSD-3-Clause" } +authors = [ + { name = "hengguo", email = "rd_gh@unitree.com" } +] +keywords = ["unitree", "robotics", "deployment"] + +dependencies = [ + "tyro", + "draccus", + "datasets==3.3.0", + "meshcat", + "pyrealsense2", + "numpy", + "opencv-python", + "mujoco", + "matplotlib", + "dm_env", + "torch>=2.2.1,<2.8.0", + "rerun-sdk>=0.21.0,<0.23.0", +] +[tool.setuptools] +packages = ["unitree_deploy"] + +[project.optional-dependencies] +lerobot = [ + "lerobot @ git+https://github.com/huggingface/lerobot.git@0878c68" +] + +[tool.ruff] +line-length = 110 +target-version = "py310" +exclude = ["build", "venv", "__pycache__"] +fix = true +show-fixes = true + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] +ignore = ["N801"] + +[tool.ruff.lint.per-file-ignores] +"arm_indexs.py" = ["N815"] \ No newline at end of file diff --git a/unitree_deploy/scripts/robot_client.py b/unitree_deploy/scripts/robot_client.py new file mode 100644 index 0000000..da98126 --- /dev/null +++ b/unitree_deploy/scripts/robot_client.py @@ -0,0 +1,198 @@ +import argparse +import os +import time +import cv2 +import numpy as np +import torch +import tqdm + +from typing import Any, Deque, MutableMapping, OrderedDict +from collections import deque +from pathlib import Path + +from unitree_deploy.real_unitree_env import make_real_env +from unitree_deploy.utils.eval_utils import ( + ACTTemporalEnsembler, + LongConnectionClient, + populate_queues, +) + +# ----------------------------------------------------------------------------- +# Network & environment defaults +# ----------------------------------------------------------------------------- +os.environ["http_proxy"] = "" +os.environ["https_proxy"] = "" +HOST = "127.0.0.1" +PORT = 8000 +BASE_URL = f"http://{HOST}:{PORT}" + +# fmt: off +INIT_POSE = { + 'g1_dex1': np.array([0.10559805, 0.02726714, -0.01210221, -0.33341318, -0.22513399, -0.02627627, -0.15437093, 0.1273793 , -0.1674708 , -0.11544029, -0.40095493, 0.44332668, 0.11566751, 0.3936641, 5.4, 5.4], dtype=np.float32), + 'z1_dual_dex1_realsense': np.array([-1.0262332, 1.4281361, -1.2149128, 0.6473399, -0.12425245, 0.44945636, 0.89584476, 1.2593982, -1.0737865, 0.6672816, 0.39730102, -0.47400007, 0.9894176, 0.9817477 ], dtype=np.float32), + 'z1_realsense': np.array([-0.06940782, 1.4751548, -0.7554075, 1.0501366, 0.02931615, -0.02810347, -0.99238837], dtype=np.float32), +} +ZERO_ACTION = { + 'g1_dex1': torch.zeros(16, dtype=torch.float32), + 'z1_dual_dex1_realsense': torch.zeros(14, dtype=torch.float32), + 'z1_realsense': torch.zeros(7, dtype=torch.float32), +} +CAM_KEY = { + 'g1_dex1': 'cam_right_high', + 'z1_dual_dex1_realsense': 'cam_high', + 'z1_realsense': 'cam_high', +} +# fmt: on + + +def prepare_observation(args: argparse.Namespace, obs: Any) -> OrderedDict: + """ + Convert a raw env observation into the model's expected input dict. 
+ """ + rgb_image = cv2.cvtColor( + obs.observation["images"][CAM_KEY[args.robot_type]], cv2.COLOR_BGR2RGB) + observation = { + "observation.images.top": + torch.from_numpy(rgb_image).permute(2, 0, 1), + "observation.state": + torch.from_numpy(obs.observation["qpos"]), + "action": ZERO_ACTION[args.robot_type], + } + return OrderedDict(observation) + + +def run_policy( + args: argparse.Namespace, + env: Any, + client: LongConnectionClient, + temporal_ensembler: ACTTemporalEnsembler, + cond_obs_queues: MutableMapping[str, Deque[torch.Tensor]], + output_dir: Path, +) -> None: + """ + Single rollout loop: + 1) warm start the robot, + 2) stream observations, + 3) fetch actions from the policy server, + 4) execute with temporal ensembling for smoother control. + """ + + _ = env.step(INIT_POSE[args.robot_type]) + time.sleep(2.0) + t = 0 + + while True: + # Gapture observation + obs = env.get_observation(t) + # Format observation + obs = prepare_observation(args, obs) + cond_obs_queues = populate_queues(cond_obs_queues, obs) + # Call server to get actions + pred_actions = client.predict_action(args.language_instruction, + cond_obs_queues).unsqueeze(0) + # Keep only the next horizon of actions and apply temporal ensemble smoothing + actions = temporal_ensembler.update( + pred_actions[:, :args.action_horizon])[0] + + # Execute the actions + for n in range(args.exe_steps): + action = actions[n].cpu().numpy() + print(f">>> Exec => step {n} action: {action}", flush=True) + print("---------------------------------------------") + + # Maintain real-time loop at `control_freq` Hz + t1 = time.time() + obs = env.step(action) + time.sleep(max(0, 1 / args.control_freq - time.time() + t1)) + t += 1 + + # Prime the queue for the next action step (except after the last one in this chunk) + if n < args.exe_steps - 1: + obs = prepare_observation(args, obs) + cond_obs_queues = populate_queues(cond_obs_queues, obs) + + +def run_eval(args: argparse.Namespace) -> None: + client = LongConnectionClient(BASE_URL) + + # Initialize ACT temporal moving-averge smoother + temporal_ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff=0.01, + chunk_size=args.action_horizon, + exe_steps=args.exe_steps) + temporal_ensembler.reset() + + # Initialize observation and action horizon queue + cond_obs_queues = { + "observation.images.top": deque(maxlen=args.observation_horizon), + "observation.state": deque(maxlen=args.observation_horizon), + "action": deque( + maxlen=16), # NOTE: HAND CODE AS THE MODEL PREDCIT FUTURE 16 STEPS + } + + env = make_real_env( + robot_type=args.robot_type, + dt=1 / args.control_freq, + ) + env.connect() + + try: + for episode_idx in tqdm.tqdm(range(0, args.num_rollouts_planned)): + output_dir = Path(args.output_dir) / f"episode_{episode_idx:03d}" + output_dir.mkdir(parents=True, exist_ok=True) + run_policy(args, env, client, temporal_ensembler, cond_obs_queues, + output_dir) + finally: + env.close() + env.close() + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--robot_type", + type=str, + default="g1_dex1", + help="The type of the robot embodiment.") + parser.add_argument( + "--action_horizon", + type=int, + default=16, + help="Number of future actions, predicted by the policy, to keep", + ) + parser.add_argument( + "--exe_steps", + type=int, + default=16, + help= + "Number of future actions to execute, which must be less than the above action horizon.", + ) + parser.add_argument( + "--observation_horizon", + type=int, + default=2, + 
help="Number of most recent frames/states to consider.", + ) + parser.add_argument( + "--language_instruction", + type=str, + default="Pack black camera into box", + help="The language instruction provided to the policy server.", + ) + parser.add_argument("--num_rollouts_planned", + type=int, + default=10, + help="The number of rollouts to run.") + parser.add_argument("--output_dir", + type=str, + default="./results", + help="The directory for saving results.") + parser.add_argument("--control_freq", + type=float, + default=30, + help="The Low-level control frequency in Hz.") + return parser + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + run_eval(args) diff --git a/unitree_deploy/test/arm/g1/test_g1_arm.py b/unitree_deploy/test/arm/g1/test_g1_arm.py new file mode 100644 index 0000000..bed16b1 --- /dev/null +++ b/unitree_deploy/test/arm/g1/test_g1_arm.py @@ -0,0 +1,105 @@ +import time + +import numpy as np +import pinocchio as pin + +from unitree_deploy.robot.robot_configs import g1_motors +from unitree_deploy.robot_devices.arm.configs import G1ArmConfig +from unitree_deploy.robot_devices.arm.utils import make_arm_motors_buses_from_configs +from unitree_deploy.robot_devices.robots_devices_utils import precise_wait + +if __name__ == "__main__": + # ============== Arm Configuration ============== + def g1_dual_arm_default_factory(): + return { + "g1": G1ArmConfig( + init_pose=np.zeros(14), + motors=g1_motors, + mock=False, + ), + } + + # ============================================== + + # Initialize and connect to the robotic arm + arm = make_arm_motors_buses_from_configs(g1_dual_arm_default_factory()) + for name in arm: + arm[name].connect() + time.sleep(1.5) + print("✅ Arm connected. Waiting to start...") + + # Define initial target poses for left and right arms + L_tf_target = pin.SE3( + pin.Quaternion(1, 0, 0, 0), + np.array([0.25, +0.25, 0.1]), + ) + + R_tf_target = pin.SE3( + pin.Quaternion(1, 0, 0, 0), + np.array([0.25, -0.25, 0.1]), + ) + + rotation_speed = 0.005 # Rotation speed in radians per iteration + + # Motion parameters + control_dt = 1 / 50 # Control cycle duration (20ms) + step = 0 + max_step = 240 + + initial_data_received = True # Used to switch from drive to schedule mode + # Wait for user input to start the motion loop + user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n") + if user_input.lower() == "s": + try: + while True: + # Define timing for the control cycle + t_cycle_end = time.monotonic() + control_dt + t_command_target = t_cycle_end + control_dt + + direction = 1 if step <= 120 else -1 + angle = rotation_speed * (step if step <= 120 else (240 - step)) + + cos_half_angle = np.cos(angle / 2) + sin_half_angle = np.sin(angle / 2) + + L_quat = pin.Quaternion(cos_half_angle, 0, sin_half_angle, 0) # 绕 Y 轴旋转 + R_quat = pin.Quaternion(cos_half_angle, 0, 0, sin_half_angle) # 绕 Z 轴旋转 + + delta_l = np.array([0.001, 0.001, 0.001]) * direction + delta_r = np.array([0.001, -0.001, 0.001]) * direction + + L_tf_target.translation += delta_l + R_tf_target.translation += delta_r + + L_tf_target.rotation = L_quat.toRotationMatrix() + R_tf_target.rotation = R_quat.toRotationMatrix() + + # Solve inverse kinematics for the arm + for name in arm: + sol_q, sol_tauff = arm[name].arm_ik(L_tf_target.homogeneous, R_tf_target.homogeneous) + print(f"Arm {name} solution: q={sol_q}, tauff={sol_tauff}") + # Determine command mode + cmd_target = "drive_to_waypoint" if initial_data_received else 
"schedule_waypoint" + + # Send joint target command to arm + arm[name].write_arm( + q_target=sol_q, + tauff_target=sol_tauff, # Optional: send torque feedforward + time_target=t_command_target - time.monotonic() + time.perf_counter(), + cmd_target="schedule_waypoint", + ) + + # Update step and reset after full cycle + step = (step + 1) % (max_step + 1) + initial_data_received = False + + # Wait until end of control cycle + precise_wait(t_cycle_end) + + except KeyboardInterrupt: + # Handle Ctrl+C to safely disconnect + print("\n🛑 Ctrl+C detected. Disconnecting arm...") + for name in arm: + arm[name].disconnect() + + print("✅ Arm disconnected. Exiting.") diff --git a/unitree_deploy/test/arm/g1/test_g1_env.py b/unitree_deploy/test/arm/g1/test_g1_env.py new file mode 100644 index 0000000..87e68a9 --- /dev/null +++ b/unitree_deploy/test/arm/g1/test_g1_env.py @@ -0,0 +1,91 @@ +import math +import time + +import numpy as np +import pinocchio as pin + +from unitree_deploy.real_unitree_env import make_real_env +from unitree_deploy.utils.rerun_visualizer import RerunLogger, flatten_images, visualization_data +from unitree_deploy.utils.rich_logger import log_info +from unitree_deploy.utils.trajectory_generator import sinusoidal_gripper_motion + +if __name__ == "__main__": + period = 2.0 + motion_period = 2.0 + motion_amplitude = 0.99 + + rerun_logger = RerunLogger() + env = make_real_env(robot_type="g1_dex1", dt=1 / 30) + env.connect() + + # Define initial target poses for left and right arms + L_tf_target = pin.SE3( + pin.Quaternion(1, 0, 0, 0), + np.array([0.25, +0.25, 0.1]), + ) + + R_tf_target = pin.SE3( + pin.Quaternion(1, 0, 0, 0), + np.array([0.25, -0.25, 0.1]), + ) + + rotation_speed = 0.005 # Rotation speed in radians per iteration + + # Motion parameters + control_dt = 1 / 50 # Control cycle duration (20ms) + step = 0 + max_step = 240 + + initial_data_received = True # Used to switch from drive to schedule mode + # Wait for user input to start the motion loop + user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n") + if user_input.lower() == "s": + try: + current_time = math.pi / 2 + idx = 0 # Initialize index for logging + while True: + # Define timing for the control cycle + t_cycle_end = time.monotonic() + control_dt + t_command_target = t_cycle_end + control_dt + + direction = 1 if step <= 120 else -1 + angle = rotation_speed * (step if step <= 120 else (240 - step)) + + cos_half_angle = np.cos(angle / 2) + sin_half_angle = np.sin(angle / 2) + + L_quat = pin.Quaternion(cos_half_angle, 0, sin_half_angle, 0) # 绕 Y 轴旋转 + R_quat = pin.Quaternion(cos_half_angle, 0, 0, sin_half_angle) # 绕 Z 轴旋转 + + delta_l = np.array([0.001, 0.001, 0.001]) * direction + delta_r = np.array([0.001, -0.001, 0.001]) * direction + + L_tf_target.translation += delta_l + R_tf_target.translation += delta_r + + L_tf_target.rotation = L_quat.toRotationMatrix() + R_tf_target.rotation = R_quat.toRotationMatrix() + + # Solve inverse kinematics for the left arm + for arm_name in env.robot.arm: + arm_sol_q, arm_sol_tauff = env.robot.arm[arm_name].arm_ik( + L_tf_target.homogeneous, R_tf_target.homogeneous + ) + + gripper_target_q = sinusoidal_gripper_motion( + period=motion_period, amplitude=motion_amplitude, current_time=time.perf_counter() + ) + action = np.concatenate([arm_sol_q, gripper_target_q], axis=0) + step_type, reward, _, observation = env.step(action) + + idx += 1 + visualization_data(idx, flatten_images(observation), observation["qpos"], arm_sol_q, rerun_logger) + + # 
+                step_type, reward, _, observation = env.step(action)
+
+                idx += 1
+                visualization_data(idx, flatten_images(observation), observation["qpos"], arm_sol_q, rerun_logger)
+
+                # Update step and reset after full cycle
+                current_time += control_dt
+                step = (step + 1) % (max_step + 1)
+
+        except KeyboardInterrupt:
+            # Handle Ctrl+C to safely disconnect
+            log_info("\n🛑 Ctrl+C detected. Disconnecting arm...")
+            env.close()
diff --git a/unitree_deploy/test/arm/z1/test_z1_arm.py b/unitree_deploy/test/arm/z1/test_z1_arm.py
new file mode 100644
index 0000000..5f2e3ef
--- /dev/null
+++ b/unitree_deploy/test/arm/z1/test_z1_arm.py
@@ -0,0 +1,81 @@
+import math
+import time
+
+import numpy as np
+import pinocchio as pin
+
+from unitree_deploy.robot.robot_configs import z1_motors
+from unitree_deploy.robot_devices.arm.utils import make_arm_motors_bus
+from unitree_deploy.robot_devices.robots_devices_utils import precise_wait
+from unitree_deploy.utils.trajectory_generator import generate_rotation, sinusoidal_gripper_motion
+
+if __name__ == "__main__":
+    # ============== Arm Configuration ==============
+    arm_type = "z1"
+    arm_kwargs = {
+        "arm_type": arm_type,
+        "init_pose": [0.00623, 1.11164, -0.77531, -0.32167, -0.005, 0.0, 0.0],  # Initial joint pose
+        "motors": z1_motors,
+    }
+    # ==============================================
+
+    # Initialize and connect to the robotic arm
+    arm = make_arm_motors_bus(**arm_kwargs)
+    arm.connect()
+    time.sleep(2)
+    print("✅ Arm connected. Waiting to start...")
+
+    # Define the arm's initial target pose
+    arm_tf_target = pin.SE3(pin.Quaternion(1, 0, 0, 0), np.array([0.2, 0, 0.4]))
+
+    # Motion parameters
+    rotation_speed = 0.01  # Rotation speed (rad per step)
+    control_dt = 1 / 30  # Control cycle duration (~33 ms)
+    step = 0
+    max_step = 240  # Full motion cycle
+
+    # Wait for user input to start the motion loop
+    user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n")
+    if user_input.lower() == "s":
+        try:
+            current_time = math.pi / 2
+            while True:
+                # Define timing for the control cycle
+                t_cycle_end = time.monotonic() + control_dt
+                t_command_target = t_cycle_end + control_dt
+
+                # Generate target rotation and translation
+                L_quat, R_quat, delta_l, delta_r = generate_rotation(step, rotation_speed, max_step)
+                arm_tf_target.translation += delta_l
+                # delta_r is not used in this context
+                arm_tf_target.rotation = L_quat.toRotationMatrix()
+
+                # Solve inverse kinematics for the arm
+                arm_sol_q, arm_sol_tauff = arm.arm_ik(arm_tf_target.homogeneous)
+
+                # Gripper target from a sinusoid, shifted down by 1
+                target_gripper = (
+                    sinusoidal_gripper_motion(period=4.0, amplitude=0.99, current_time=current_time) - 1
+                )
+
+                target_arm = np.concatenate((arm_sol_q, target_gripper), axis=0)  # Append the gripper target
+
+                arm.write_arm(
+                    q_target=target_arm,
+                    # tauff_target=arm_sol_tauff,  # Optional: send torque feedforward
+                    time_target=t_command_target - time.monotonic() + time.perf_counter(),
+                    cmd_target="schedule_waypoint",
+                )
+
+                # Update step and reset after full cycle
+                step = (step + 1) % (max_step + 1)
+                current_time += control_dt
+
+                # Wait until end of control cycle
+                precise_wait(t_cycle_end)
+
+        except KeyboardInterrupt:
+            # Handle Ctrl+C to safely disconnect
+            print("\n🛑 Ctrl+C detected. Disconnecting arm...")
+            arm.disconnect()
+            print("✅ Arm disconnected. Exiting.")
Exiting.") diff --git a/unitree_deploy/test/arm/z1/test_z1_dual_arm.py b/unitree_deploy/test/arm/z1/test_z1_dual_arm.py new file mode 100644 index 0000000..f9ba0be --- /dev/null +++ b/unitree_deploy/test/arm/z1/test_z1_dual_arm.py @@ -0,0 +1,112 @@ +import time + +import numpy as np +import pinocchio as pin + +from unitree_deploy.robot_devices.arm.configs import Z1DualArmConfig +from unitree_deploy.robot_devices.arm.utils import make_arm_motors_buses_from_configs +from unitree_deploy.robot_devices.robots_devices_utils import precise_wait +from unitree_deploy.utils.trajectory_generator import generate_rotation + +if __name__ == "__main__": + # ============== Arm Configuration ============== + def z1_dual_arm_single_config_factory(): + return { + "z1_dual": Z1DualArmConfig( + left_robot_ip="127.0.0.1", + left_robot_port1=8073, + left_robot_port2=8074, + right_robot_ip="127.0.0.1", + right_robot_port1=8071, + right_robot_port2=8072, + init_pose_left=[0, 0, 0, 0, 0, 0], + init_pose_right=[0, 0, 0, 0, 0, 0], + control_dt=1 / 250.0, + motors={ + # name: (index, model) + "kLeftWaist": [0, "z1-joint"], + "kLeftShoulder": [1, "z1-joint"], + "kLeftElbow": [2, "z1-joint"], + "kLeftForearmRoll": [3, "z1-joint"], + "kLeftWristAngle": [4, "z1-joint"], + "kLeftWristRotate": [5, "z1-joint"], + "kRightWaist": [7, "z1-joint"], + "kRightShoulder": [8, "z1-joint"], + "kRightElbow": [9, "z1-joint"], + "kRightForearmRoll": [10, "z1-joint"], + "kRightWristAngle": [11, "z1-joint"], + "kRightWristRotate": [12, "z1-joint"], + }, + ), + } + + # ============================================== + + # Initialize and connect to the robotic arm + arm = make_arm_motors_buses_from_configs(z1_dual_arm_single_config_factory()) + for name in arm: + arm[name].connect() + time.sleep(1.5) + + print("✅ Arm connected. 
Waiting to start...") + + # Define initial target poses for left and right arms + L_tf_target = pin.SE3(pin.Quaternion(1, 0, 0, 0), np.array([0.2, 0, 0.4])) + R_tf_target = pin.SE3(pin.Quaternion(1, 0, 0, 0), np.array([0.2, 0, 0.3])) + + # Motion parameters + rotation_speed = 0.01 # Rotation speed (rad per step) + control_dt = 1 / 30 # Control cycle duration (20ms) + step = 0 + max_step = 240 # Full motion cycle + initial_data_received = True # Used to switch from drive to schedule mode + + # Wait for user input to start the motion loop + user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n") + if user_input.lower() == "s": + try: + while True: + # Define timing for the control cycle + t_cycle_end = time.monotonic() + control_dt + t_command_target = t_cycle_end + control_dt + + # Generate target rotation and translation + L_quat, R_quat, delta_l, delta_r = generate_rotation(step, rotation_speed, max_step) + + # Apply translation deltas to target pose + L_tf_target.translation += delta_l + R_tf_target.translation += delta_r + + # Apply rotation to target pose + L_tf_target.rotation = L_quat.toRotationMatrix() + R_tf_target.rotation = R_quat.toRotationMatrix() + + # Solve inverse kinematics for the left arm + for name in arm: + sol_q, sol_tauff = arm[name].arm_ik(L_tf_target.homogeneous, R_tf_target.homogeneous) + + # Determine command mode + cmd_target = "drive_to_waypoint" if initial_data_received else "schedule_waypoint" + + # Send joint target command to arm + for name in arm: + arm[name].write_arm( + q_target=sol_q, + # tauff_target=sol_tauff, # Optional: send torque feedforward + time_target=t_command_target - time.monotonic() + time.perf_counter(), + cmd_target=cmd_target, + ) + + # Update step and reset after full cycle + step = (step + 1) % (max_step + 1) + initial_data_received = False + + # Wait until end of control cycle + precise_wait(t_cycle_end) + + except KeyboardInterrupt: + # Handle Ctrl+C to safely disconnect + print("\n🛑 Ctrl+C detected. Disconnecting arm...") + for name in arm: + arm[name].disconnect() + print("✅ Arm disconnected. 
Exiting.") diff --git a/unitree_deploy/test/arm/z1/test_z1_env.py b/unitree_deploy/test/arm/z1/test_z1_env.py new file mode 100644 index 0000000..ca7bf9d --- /dev/null +++ b/unitree_deploy/test/arm/z1/test_z1_env.py @@ -0,0 +1,65 @@ +import math +import time + +import numpy as np +import pinocchio as pin + +from unitree_deploy.real_unitree_env import make_real_env +from unitree_deploy.utils.rerun_visualizer import RerunLogger, flatten_images, visualization_data +from unitree_deploy.utils.rich_logger import log_info +from unitree_deploy.utils.trajectory_generator import generate_rotation, sinusoidal_gripper_motion + +if __name__ == "__main__": + rerun_logger = RerunLogger() + env = make_real_env(robot_type="z1_realsense", dt=1 / 30) + env.connect() + + # Define initial target poses for left and right arms + arm_tf_target = pin.SE3(pin.Quaternion(1, 0, 0, 0), np.array([0.2, 0, 0.4])) + + # Motion parameters + rotation_speed = 0.01 # Rotation speed (rad per step) + control_dt = 1 / 30 # Control cycle duration (20ms) + step = 0 + max_step = 240 # Full motion cycle + + # Wait for user input to start the motion loop + user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n") + if user_input.lower() == "s": + try: + current_time = math.pi / 2 + idx = 0 # Initialize index for logging + while True: + # Define timing for the control cycle + t_cycle_end = time.monotonic() + control_dt + t_command_target = t_cycle_end + control_dt + + # Generate target rotation and translation + L_quat, R_quat, delta_l, delta_r = generate_rotation(step, rotation_speed, max_step) + arm_tf_target.translation += delta_l + # delta_r is not used in this context + arm_tf_target.rotation = L_quat.toRotationMatrix() + + # Solve inverse kinematics for the left arm + for arm_name in env.robot.arm: + arm_sol_q, arm_sol_tauff = env.robot.arm[arm_name].arm_ik(arm_tf_target.homogeneous) + + # Generate sinusoidal motion for the gripper + target_gripper = ( + sinusoidal_gripper_motion(period=4.0, amplitude=0.99, current_time=current_time) - 1 + ) # Adjust target_q by subtracting 1 + + target_arm = np.concatenate((arm_sol_q, target_gripper), axis=0) # Add a zero for the gripper + step_type, reward, _, observation = env.step(target_arm) + + idx += 1 + visualization_data(idx, flatten_images(observation), observation["qpos"], target_arm, rerun_logger) + + # Update step and reset after full cycle + current_time += control_dt + step = (step + 1) % (max_step + 1) + + except KeyboardInterrupt: + # Handle Ctrl+C to safely disconnect + log_info("\n🛑 Ctrl+C detected. 
Disconnecting arm...") + env.close() diff --git a/unitree_deploy/test/camera/test_image_client_camera.py b/unitree_deploy/test/camera/test_image_client_camera.py new file mode 100644 index 0000000..3c3f13a --- /dev/null +++ b/unitree_deploy/test/camera/test_image_client_camera.py @@ -0,0 +1,64 @@ +import time + +import cv2 +import numpy as np +import torch +from tqdm import tqdm + +from unitree_deploy.robot_devices.cameras.configs import ImageClientCameraConfig +from unitree_deploy.robot_devices.cameras.utils import make_cameras_from_configs +from unitree_deploy.utils.rich_logger import log_success + + +# ============================From configs============================ +def run_camera(): + def image_client_default_factory(): + return { + "imageclient": ImageClientCameraConfig( + head_camera_type="opencv", + head_camera_id_numbers=[4], + head_camera_image_shape=[480, 1280], # Head camera resolution + wrist_camera_type="opencv", + wrist_camera_id_numbers=[0, 2], + wrist_camera_image_shape=[480, 640], # Wrist camera resolution + aspect_ratio_threshold=2.0, + fps=30, + mock=False, + ), + } + + # =========================================== + + cameras = make_cameras_from_configs(image_client_default_factory()) + print(cameras) + for name in cameras: + cameras[name].connect() + log_success(f"Connecting {name} cameras.") + + for _ in tqdm(range(20), desc="Camera warming up"): + for name in cameras: + cameras[name].async_read() + time.sleep(1 / 30) + + while True: + images = {} + for name in cameras: + output = cameras[name].async_read() + if isinstance(output, dict): + for k, v in output.items(): + images[k] = torch.from_numpy(v) + else: + images[name] = torch.from_numpy(output) + + image_list = [np.stack([img.numpy()] * 3, axis=-1) if img.ndim == 2 else img.numpy() for img in images.values()] + + stacked_image = np.hstack(image_list) + cv2.imshow("Stacked Image", stacked_image) + + if (cv2.waitKey(1) & 0xFF) == ord("q"): + cv2.destroyAllWindows() + break + + +if __name__ == "__main__": + run_camera() diff --git a/unitree_deploy/test/camera/test_realsense_camera.py b/unitree_deploy/test/camera/test_realsense_camera.py new file mode 100644 index 0000000..8f87639 --- /dev/null +++ b/unitree_deploy/test/camera/test_realsense_camera.py @@ -0,0 +1,66 @@ +import time + +import cv2 +import numpy as np + +from unitree_deploy.robot_devices.cameras.configs import IntelRealSenseCameraConfig +from unitree_deploy.robot_devices.cameras.utils import make_cameras_from_configs +from unitree_deploy.utils.rich_logger import log_success + + +def run_camera(): + # =========================================== + def intelrealsense_camera_default_factory(): + return { + "cam_high": IntelRealSenseCameraConfig( + serial_number="044122071036", + fps=30, + width=640, + height=480, + ), + "cam_wrist": IntelRealSenseCameraConfig( + serial_number="419122270615", + fps=30, + width=640, + height=480, + ), + } + + # =========================================== + + cameras = make_cameras_from_configs(intelrealsense_camera_default_factory()) + for name in cameras: + cameras[name].connect() + log_success(f"Connecting {name} cameras.") + + for _ in range(20): + for name in cameras: + cameras[name].async_read() + time.sleep(1 / 30) + + while True: + images = [] + for name in cameras: + frame = cameras[name].async_read() + if frame is not None: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + cv2.putText(frame, name, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + images.append(frame) + + if images: + rows = [] + for i 
in range(0, len(images), 2):
+                row = np.hstack(images[i : i + 2])
+                rows.append(row)
+            canvas = np.vstack(rows)
+
+            cv2.imshow("All Cameras", canvas)
+
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            break
+
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    run_camera()
diff --git a/unitree_deploy/test/camera/test_usb_camera.py b/unitree_deploy/test/camera/test_usb_camera.py
new file mode 100644
index 0000000..e6c47ff
--- /dev/null
+++ b/unitree_deploy/test/camera/test_usb_camera.py
@@ -0,0 +1,95 @@
+import time
+
+import cv2
+import numpy as np
+import tyro
+from tqdm import tqdm
+
+from unitree_deploy.robot_devices.cameras.configs import OpenCVCameraConfig
+from unitree_deploy.robot_devices.cameras.utils import make_camera, make_cameras_from_configs
+from unitree_deploy.utils.rich_logger import log_success
+
+
+def usb_camera_default_factory():
+    return {
+        "cam_high": OpenCVCameraConfig(
+            camera_index="/dev/video1",
+            fps=30,
+            width=640,
+            height=480,
+        ),
+        "cam_left_wrist": OpenCVCameraConfig(
+            camera_index="/dev/video3",
+            fps=30,
+            width=640,
+            height=480,
+        ),
+        "cam_right_wrist": OpenCVCameraConfig(
+            camera_index="/dev/video5",
+            fps=30,
+            width=640,
+            height=480,
+        ),
+    }
+
+
+def run_cameras(camera_style: int = 0):
+    """
+    Runs camera(s) based on the specified style.
+
+    Args:
+        camera_style (int):
+            0 - Single camera (OpenCV).
+            1 - Multiple cameras from config.
+    """
+
+    if camera_style == 0:
+        # ========== Single camera ==========
+        camera_kwargs = {"camera_type": "opencv", "camera_index": "/dev/video5", "mock": False}
+        camera = make_camera(**camera_kwargs)
+        camera.connect()
+        log_success("Connecting camera.")
+
+        while True:
+            color_image = camera.read()
+            color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
+            cv2.imshow("Camera", color_image)
+            if cv2.waitKey(1) & 0xFF == ord("q"):
+                break
+
+    elif camera_style == 1:
+        # ========== Multi-camera from configs ==========
+        cameras = make_cameras_from_configs(usb_camera_default_factory())
+
+        for name in cameras:
+            cameras[name].connect()
+            log_success(f"Connecting {name} camera.")
+
+        # Camera warm-up
+        for _ in tqdm(range(20), desc="Camera warming up"):
+            for name in cameras:
+                cameras[name].async_read()
+                time.sleep(1 / 30)
+
+        while True:
+            images = {}
+            for name in cameras:
+                images[name] = cameras[name].async_read()
+
+            # async_read() returns numpy arrays here; expand grayscale frames to 3 channels
+            image_list = [
+                np.stack([img] * 3, axis=-1) if img.ndim == 2 else img for img in images.values()
+            ]
+
+            stacked_image = np.hstack(image_list)
+            cv2.imshow("Multi-Camera View", stacked_image)
+
+            if (cv2.waitKey(1) & 0xFF) == ord("q"):
+                cv2.destroyAllWindows()
+                break
+
+    else:
+        raise ValueError(f"Unsupported camera_style: {camera_style}")
+
+
+if __name__ == "__main__":
+    tyro.cli(run_cameras)
diff --git a/unitree_deploy/test/endeffector/test_dex1.py b/unitree_deploy/test/endeffector/test_dex1.py
new file mode 100644
index 0000000..6a6c0a1
--- /dev/null
+++ b/unitree_deploy/test/endeffector/test_dex1.py
@@ -0,0 +1,60 @@
+import time
+
+import tyro
+
+from unitree_deploy.robot_devices.endeffector.utils import (
+    Dex1_GripperConfig,
+    make_endeffector_motors_buses_from_configs,
+)
+from unitree_deploy.robot_devices.robots_devices_utils import precise_wait
+from unitree_deploy.utils.rich_logger import log_success
+from unitree_deploy.utils.trajectory_generator import sinusoidal_single_gripper_motion
+
+period = 2.0
+motion_period = 2.0
+motion_amplitude = 0.99
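+
+# The DDS topics follow the Dex1 convention used throughout this repo:
+# "rt/dex1/<side>/state" for feedback and "rt/dex1/<side>/cmd" for commands.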
"kLeftGripper": [0, "z1_gripper-joint"], + }, + topic_gripper_state="rt/dex1/left/state", + topic_gripper_command="rt/dex1/left/cmd", + ), + "right": Dex1_GripperConfig( + unit_test=True, + motors={ + "kRightGripper": [1, "z1_gripper-joint"], + }, + topic_gripper_state="rt/dex1/right/state", + topic_gripper_command="rt/dex1/right/cmd", + ), + } + + +def run_gripper(): + control_dt = 1 / 30 + + log_success("Running gripper in style 1 (multi-bus from config)") + endeffectors = make_endeffector_motors_buses_from_configs(gripper_default_factory()) + + for name in endeffectors: + endeffectors[name].connect() + log_success(f"Connected endeffector '{name}'.") + + while True: + t_cycle_end = time.monotonic() + control_dt + target_q = sinusoidal_single_gripper_motion( + period=motion_period, amplitude=motion_amplitude, current_time=time.perf_counter() + ) + for name in endeffectors: + endeffectors[name].write_endeffector(q_target=target_q) + precise_wait(t_cycle_end) + + +if __name__ == "__main__": + tyro.cli(run_gripper) diff --git a/unitree_deploy/test/test_replay.py b/unitree_deploy/test/test_replay.py new file mode 100644 index 0000000..3d5eb68 --- /dev/null +++ b/unitree_deploy/test/test_replay.py @@ -0,0 +1,44 @@ +""" +python test/test_replay.py --repo-id unitreerobotics/G1_CameraPackaging_NewDataset --robot_type g1_dex1 +python test/test_replay.py --repo-id unitreerobotics/Z1_StackBox_Dataset --robot_type z1_realsense +python test/test_replay.py --repo-id unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2 --robot_type z1_dual_dex1_realsense +""" + +import tyro +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +from unitree_deploy.real_unitree_env import make_real_env +from unitree_deploy.utils.rerun_visualizer import RerunLogger, flatten_images, visualization_data +from unitree_deploy.utils.rich_logger import log_info + + +# Replay a specific episode from the LeRobot dataset using the real environment robot_type:(e.g., g1_dex1, z1_realsense, z1_dual_dex1_realsense) +def replay_lerobot_data(repo_id: str, robot_type: str, root: str | None = None, episode: int = 145): + dataset = LeRobotDataset(repo_id, root=root, episodes=[episode]) + actions = dataset.hf_dataset.select_columns("action") + init_pose_arm = actions[0]["action"].numpy()[:14] if robot_type == "g1" else actions[0]["action"].numpy() + rerun_logger = RerunLogger() + + env = make_real_env(robot_type=robot_type, dt=1 / 30, init_pose_arm=init_pose_arm) + env.connect() + + try: + # Wait for user input to start the motion loop + user_input = input("Please enter the start signal (enter 's' to start the subsequent program): \n") + if user_input.lower() == "s": + log_info("Replaying episode") + for idx in range(dataset.num_frames): + action = actions[idx]["action"].numpy() + if robot_type == "z1_realsense": + action[-1] = -action[-1] + step_type, reward, _, observation = env.step(action) + visualization_data(idx, flatten_images(observation), observation["qpos"], action, rerun_logger) + env.close() + except KeyboardInterrupt: + # Handle Ctrl+C to safely disconnect + log_info("\n🛑 Ctrl+C detected. 
Disconnecting arm...") + env.close() + + +if __name__ == "__main__": + tyro.cli(replay_lerobot_data) diff --git a/unitree_deploy/unitree_deploy/eval_dataset_env.py b/unitree_deploy/unitree_deploy/eval_dataset_env.py new file mode 100644 index 0000000..0bacb6d --- /dev/null +++ b/unitree_deploy/unitree_deploy/eval_dataset_env.py @@ -0,0 +1,105 @@ +import collections +import time + +import matplotlib.pyplot as plt +import numpy as np +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +from unitree_deploy.utils.rerun_visualizer import RerunLogger, visualization_data + + +def extract_observation(step: dict): + observation = {} + + for key, value in step.items(): + if key.startswith("observation.images."): + if isinstance(value, np.ndarray) and value.ndim == 3 and value.shape[-1] in [1, 3]: + value = np.transpose(value, (2, 0, 1)) + observation[key] = value + + elif key == "observation.state": + observation[key] = value + + return observation + + +class DatasetEvalEnv: + def __init__(self, repo_id: str, episode_index: int = 0, visualization: bool = True): + self.dataset = LeRobotDataset(repo_id=repo_id) + + self.visualization = visualization + if self.visualization: + self.rerun_logger = RerunLogger() + + self.from_idx = self.dataset.episode_data_index["from"][episode_index].item() + self.to_idx = self.dataset.episode_data_index["to"][episode_index].item() + self.step_idx = self.from_idx + + self.ground_truth_actions = [] + self.predicted_actions = [] + + def get_observation(self): + step = self.dataset[self.step_idx] + observation = extract_observation(step) + + state = step["observation.state"].numpy() + self.ground_truth_actions.append(step["action"].numpy()) + + if self.visualization: + visualization_data( + self.step_idx, + observation, + observation["observation.state"], + step["action"].numpy(), + self.rerun_logger, + ) + + images_observation = { + key: value.numpy() for key, value in observation.items() if key.startswith("observation.images.") + } + + obs = collections.OrderedDict() + obs["qpos"] = state + obs["images"] = images_observation + + self.step_idx += 1 + return obs + + def step(self, action): + self.predicted_actions.append(action) + + if self.step_idx == self.to_idx: + self._plot_results() + exit() + + def _plot_results(self): + ground_truth_actions = np.array(self.ground_truth_actions) + predicted_actions = np.array(self.predicted_actions) + + n_timesteps, n_dims = ground_truth_actions.shape + + fig, axes = plt.subplots(n_dims, 1, figsize=(12, 4 * n_dims), sharex=True) + fig.suptitle("Ground Truth vs Predicted Actions") + + for i in range(n_dims): + ax = axes[i] if n_dims > 1 else axes + ax.plot(ground_truth_actions[:, i], label="Ground Truth", color="blue") + ax.plot(predicted_actions[:, i], label="Predicted", color="red", linestyle="--") + ax.set_ylabel(f"Dim {i + 1}") + ax.legend() + + axes[-1].set_xlabel("Timestep") + plt.tight_layout() + plt.savefig("figure.png") + time.sleep(1) + + +def make_dataset_eval_env() -> DatasetEvalEnv: + return DatasetEvalEnv() + + +if __name__ == "__main__": + eval_dataset = DatasetEvalEnv(repo_id="unitreerobotics/G1_Brainco_PickApple_Dataset") + while True: + observation = eval_dataset.get_observation() + eval_dataset.step(observation["qpos"]) diff --git a/unitree_deploy/unitree_deploy/real_unitree_env.py b/unitree_deploy/unitree_deploy/real_unitree_env.py new file mode 100644 index 0000000..6b3a788 --- /dev/null +++ b/unitree_deploy/unitree_deploy/real_unitree_env.py @@ -0,0 +1,81 @@ +import collections +import time +from 
diff --git a/unitree_deploy/unitree_deploy/real_unitree_env.py b/unitree_deploy/unitree_deploy/real_unitree_env.py
new file mode 100644
index 0000000..6b3a788
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/real_unitree_env.py
@@ -0,0 +1,81 @@
+import collections
+import time
+from typing import List, Optional
+
+import cv2
+import dm_env
+import numpy as np
+import torch
+
+from unitree_deploy.robot.robot_utils import make_robot
+from unitree_deploy.robot_devices.robots_devices_utils import precise_wait
+from unitree_deploy.utils.rich_logger import log_success
+
+
+class UnitreeEnv:
+    def __init__(
+        self,
+        robot_type: str = "z1_realsense",
+        dt: float = 1 / 30,
+        init_pose_arm: np.ndarray | List[float] | None = None,
+    ):
+        self.control_dt = dt
+        self.init_pose_arm = init_pose_arm
+        self.state: Optional[np.ndarray] = None
+        self.robot_type = robot_type
+        self.robot = make_robot(self.robot_type)
+
+    def connect(self):
+        self.robot.connect()
+
+    def _get_obs(self):
+        observation = self.robot.capture_observation()
+
+        # Process images
+        image_dict = {
+            key.split("observation.images.")[-1]: cv2.cvtColor(value.numpy(), cv2.COLOR_BGR2RGB)
+            for key, value in observation.items()
+            if key.startswith("observation.images.")
+        }
+        # for image_key, image in image_dict.items():
+        #     cv2.imwrite(f"{image_key}.png", image)
+
+        # Process state
+        self.state = observation["observation.state"].numpy()
+
+        # Construct observation dictionary
+        obs = collections.OrderedDict(
+            qpos=self.state,
+            qvel=np.zeros_like(self.state),
+            effort=np.zeros_like(self.state),
+            images=image_dict,
+        )
+
+        return obs
+
+    def get_observation(self, t=0):
+        step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID
+        return dm_env.TimeStep(step_type=step_type, reward=0, discount=None, observation=self._get_obs())
+
+    def step(self, action) -> dm_env.TimeStep:
+        # Commands are scheduled two control periods ahead to give the
+        # trajectory interpolator a consistent target time
+        t_cycle_end = time.monotonic() + self.control_dt
+        t_command_target = t_cycle_end + self.control_dt
+        self.robot.send_action(torch.from_numpy(action), t_command_target)
+        precise_wait(t_cycle_end)
+
+        return dm_env.TimeStep(
+            step_type=dm_env.StepType.MID,
+            reward=0,
+            discount=None,
+            observation=self._get_obs(),
+        )
+
+    def close(self) -> None:
+        self.robot.disconnect()
+        log_success("Robot disconnected successfully! 🎉")
+
+
+def make_real_env(
+    robot_type: str, dt: float | None, init_pose_arm: np.ndarray | List[float] | None = None
+) -> UnitreeEnv:
+    return UnitreeEnv(robot_type, dt, init_pose_arm)
diff --git a/unitree_deploy/unitree_deploy/robot/robot.py b/unitree_deploy/unitree_deploy/robot/robot.py
new file mode 100644
index 0000000..4fa9a5b
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot/robot.py
@@ -0,0 +1,147 @@
+import time
+
+import torch
+
+from unitree_deploy.robot.robot_configs import UnitreeRobotConfig
+from unitree_deploy.robot_devices.arm.utils import make_arm_motors_buses_from_configs
+from unitree_deploy.robot_devices.cameras.utils import make_cameras_from_configs
+from unitree_deploy.robot_devices.endeffector.utils import (
+    make_endeffector_motors_buses_from_configs,
+)
+from unitree_deploy.utils.rich_logger import log_success
+
+
+class UnitreeRobot:
+    def __init__(
+        self,
+        config: UnitreeRobotConfig,
+    ):
+        self.config = config
+        self.robot_type = self.config.type
+        self.cameras = make_cameras_from_configs(self.config.cameras)
+        self.arm = make_arm_motors_buses_from_configs(self.config.arm)
+        self.endeffector = make_endeffector_motors_buses_from_configs(self.config.endeffector)
+
+        self.initial_data_received = True
+
+    def connect(self):
+        if not self.arm and not self.endeffector and not self.cameras:
+            raise ValueError(
+                "UnitreeRobot doesn't have any device to connect. See example of usage in docstring of the class."
+            )
+        # Connect the cameras
+        for name in self.cameras:
+            self.cameras[name].connect()
+            log_success(f"Connecting {name} cameras.")
+
+        # Warm up the cameras for ~20 frames before first use
+        for _ in range(20):
+            for name in self.cameras:
+                self.cameras[name].async_read()
+                time.sleep(1 / 30)
+
+        for name in self.arm:
+            self.arm[name].connect()
+            log_success(f"Connecting {name} arm.")
+
+        for name in self.endeffector:
+            self.endeffector[name].connect()
+            log_success(f"Connecting {name} endeffector.")
+
+        time.sleep(2)
+        log_success("All devices connected successfully! ✅")
+
+    def capture_observation(self):
+        """The returned observations do not have a batch dimension."""
+
+        # Create state by concatenating follower current position
+        arm_state_list = []
+        endeffector_state_list = []
+        for arm_name in self.arm:
+            arm_state = self.arm[arm_name].read_current_arm_q()
+            arm_state_list.append(torch.from_numpy(arm_state))
+
+        for endeffector_name in self.endeffector:
+            endeffector_state = self.endeffector[endeffector_name].read_current_endeffector_q()
+            endeffector_state_list.append(torch.from_numpy(endeffector_state))
+
+        state = (
+            torch.cat(arm_state_list + endeffector_state_list, dim=0)
+            if arm_state_list or endeffector_state_list
+            else torch.tensor([])
+        )
+
+        # Capture images from cameras
+        images = {}
+        for name in self.cameras:
+            output = self.cameras[name].async_read()
+            if isinstance(output, dict):
+                images.update({k: torch.from_numpy(v) for k, v in output.items()})
+            else:
+                images[name] = torch.from_numpy(output)
+
+        # Populate output dictionaries and format to pytorch
+        obs_dict = {}
+        obs_dict["observation.state"] = state
+        for name, value in images.items():
+            obs_dict[f"observation.images.{name}"] = value
+        return obs_dict
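+
+    # `send_action` expects one flat tensor: arm joint targets first (in the
+    # order the arms appear in the config), then end-effector targets. It
+    # returns the concatenation of what was actually sent.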
+    def send_action(self, action: torch.Tensor, t_command_target: float | None = None) -> torch.Tensor:
+        from_idx_arm = 0
+        to_idx_arm = 0
+        action_sent_arm = []
+        cmd_target = "drive_to_waypoint" if self.initial_data_received else "schedule_waypoint"
+
+        for arm_name in self.arm:
+            to_idx_arm += len(self.arm[arm_name].motor_names)
+            action_arm = action[from_idx_arm:to_idx_arm].numpy()
+            from_idx_arm = to_idx_arm
+
+            action_sent_arm.append(torch.from_numpy(action_arm))
+
+            self.arm[arm_name].write_arm(
+                action_arm,
+                time_target=t_command_target - time.monotonic() + time.perf_counter(),
+                cmd_target=cmd_target,
+            )
+
+        from_idx_endeffector = to_idx_arm
+        to_idx_endeffector = to_idx_arm
+
+        action_endeffector_set = []
+        for endeffector_name in self.endeffector:
+            to_idx_endeffector += len(self.endeffector[endeffector_name].motor_names)
+            action_endeffector = action[from_idx_endeffector:to_idx_endeffector].numpy()
+            from_idx_endeffector = to_idx_endeffector
+
+            action_endeffector_set.append(torch.from_numpy(action_endeffector))
+
+            self.endeffector[endeffector_name].write_endeffector(
+                action_endeffector,
+                time_target=t_command_target - time.monotonic() + time.perf_counter(),
+                cmd_target=cmd_target,
+            )
+
+        self.initial_data_received = False
+
+        return torch.cat(action_sent_arm + action_endeffector_set, dim=0)
+
+    def disconnect(self):
+        # disconnect the arms
+        for name in self.arm:
+            self.arm[name].disconnect()
+            log_success(f"disconnect {name} arm.")
+
+        for name in self.endeffector:
+            self.endeffector[name].disconnect()
+            log_success(f"disconnect {name} endeffector.")
+
+        # disconnect the cameras
+        for name in self.cameras:
+            self.cameras[name].disconnect()
+            log_success(f"disconnect {name} cameras.")
+
+    def __del__(self):
+        if getattr(self, "is_connected", False):
+            self.disconnect()
diff --git 
a/unitree_deploy/unitree_deploy/robot/robot_configs.py b/unitree_deploy/unitree_deploy/robot/robot_configs.py new file mode 100644 index 0000000..acc4576 --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot/robot_configs.py @@ -0,0 +1,270 @@ +import abc +from dataclasses import dataclass, field + +import draccus +import numpy as np + +from unitree_deploy.robot_devices.arm.configs import ( + ArmConfig, + G1ArmConfig, + Z1ArmConfig, + Z1DualArmConfig, +) +from unitree_deploy.robot_devices.cameras.configs import ( + CameraConfig, + ImageClientCameraConfig, + IntelRealSenseCameraConfig, + OpenCVCameraConfig, +) +from unitree_deploy.robot_devices.endeffector.configs import ( + Dex1_GripperConfig, + EndEffectorConfig, +) + +# ======================== arm motors ================================= +# name: (index, model) +g1_motors = { + "kLeftShoulderPitch": [0, "g1-joint"], + "kLeftShoulderRoll": [1, "g1-joint"], + "kLeftShoulderYaw": [2, "g1-joint"], + "kLeftElbow": [3, "g1-joint"], + "kLeftWristRoll": [4, "g1-joint"], + "kLeftWristPitch": [5, "g1-joint"], + "kLeftWristyaw": [6, "g1-joint"], + "kRightShoulderPitch": [7, "g1-joint"], + "kRightShoulderRoll": [8, "g1-joint"], + "kRightShoulderYaw": [9, "g1-joint"], + "kRightElbow": [10, "g1-joint"], + "kRightWristRoll": [11, "g1-joint"], + "kRightWristPitch": [12, "g1-joint"], + "kRightWristYaw": [13, "g1-joint"], +} + +z1_motors = { + "kWaist": [0, "z1-joint"], + "kShoulder": [1, "z1-joint"], + "kElbow": [2, "z1-joint"], + "kForearmRoll": [3, "z1-joint"], + "kWristAngle": [4, "z1-joint"], + "kWristRotate": [5, "z1-joint"], + "kGripper": [6, "z1-joint"], +} + +z1_dual_motors = { + "kLeftWaist": [0, "z1-joint"], + "kLeftShoulder": [1, "z1-joint"], + "kLeftElbow": [2, "z1-joint"], + "kLeftForearmRoll": [3, "z1-joint"], + "kLeftWristAngle": [4, "z1-joint"], + "kLeftWristRotate": [5, "z1-joint"], + "kRightWaist": [7, "z1-joint"], + "kRightShoulder": [8, "z1-joint"], + "kRightElbow": [9, "z1-joint"], + "kRightForearmRoll": [10, "z1-joint"], + "kRightWristAngle": [11, "z1-joint"], + "kRightWristRotate": [12, "z1-joint"], +} +# ========================================================= + + +# ======================== camera ================================= + + +def z1_intelrealsense_camera_default_factory(): + return { + "cam_high": IntelRealSenseCameraConfig( + serial_number="044122071036", + fps=30, + width=640, + height=480, + ), + # "cam_wrist": IntelRealSenseCameraConfig( + # serial_number="419122270615", + # fps=30, + # width=640, + # height=480, + # ), + } + + +def z1_dual_intelrealsense_camera_default_factory(): + return { + # "cam_left_wrist": IntelRealSenseCameraConfig( + # serial_number="218722271166", + # fps=30, + # width=640, + # height=480, + # ), + # "cam_right_wrist": IntelRealSenseCameraConfig( + # serial_number="419122270677", + # fps=30, + # width=640, + # height=480, + # ), + "cam_high": IntelRealSenseCameraConfig( + serial_number="947522071393", + fps=30, + width=640, + height=480, + ), + } + + +def g1_image_client_default_factory(): + return { + "imageclient": ImageClientCameraConfig( + head_camera_type="opencv", + head_camera_id_numbers=[4], + head_camera_image_shape=[480, 1280], # Head camera resolution + wrist_camera_type="opencv", + wrist_camera_id_numbers=[0, 2], + wrist_camera_image_shape=[480, 640], # Wrist camera resolution + aspect_ratio_threshold=2.0, + fps=30, + mock=False, + ), + } + + +def usb_camera_default_factory(): + return { + "cam_high": OpenCVCameraConfig( + camera_index="/dev/video1", + fps=30, + width=640, + 
+            height=480,
+        ),
+        "cam_left_wrist": OpenCVCameraConfig(
+            camera_index="/dev/video5",
+            fps=30,
+            width=640,
+            height=480,
+        ),
+        "cam_right_wrist": OpenCVCameraConfig(
+            camera_index="/dev/video3",
+            fps=30,
+            width=640,
+            height=480,
+        ),
+    }
+
+
+# =========================================================
+
+
+# ======================== endeffector =================================
+
+
+def dex1_default_factory():
+    return {
+        "left": Dex1_GripperConfig(
+            unit_test=True,
+            motors={
+                "kLeftGripper": [0, "z1_gripper-joint"],
+            },
+            topic_gripper_state="rt/dex1/left/state",
+            topic_gripper_command="rt/dex1/left/cmd",
+        ),
+        "right": Dex1_GripperConfig(
+            unit_test=True,
+            motors={
+                "kRightGripper": [1, "z1_gripper-joint"],
+            },
+            topic_gripper_state="rt/dex1/right/state",
+            topic_gripper_command="rt/dex1/right/cmd",
+        ),
+    }
+
+
+# =========================================================
+
+# ======================== arm =================================
+
+
+def z1_arm_default_factory(init_pose=None):
+    return {
+        "z1": Z1ArmConfig(
+            init_pose=np.zeros(7) if init_pose is None else init_pose,
+            motors=z1_motors,
+        ),
+    }
+
+
+def z1_dual_arm_single_config_factory(init_pose=None):
+    return {
+        "z1_dual": Z1DualArmConfig(
+            left_robot_ip="127.0.0.1",
+            left_robot_port1=8073,
+            left_robot_port2=8074,
+            right_robot_ip="127.0.0.1",
+            right_robot_port1=8071,
+            right_robot_port2=8072,
+            init_pose_left=np.zeros(6) if init_pose is None else init_pose[:6],
+            init_pose_right=np.zeros(6) if init_pose is None else init_pose[6:],
+            control_dt=1 / 250.0,
+            motors=z1_dual_motors,
+        ),
+    }
+
+
+def g1_dual_arm_default_factory(init_pose=None):
+    return {
+        "g1": G1ArmConfig(
+            init_pose=np.zeros(14) if init_pose is None else init_pose,
+            motors=g1_motors,
+            mock=False,
+        ),
+    }
+
+
+# =========================================================
+
+
+# robot_type: arm devices _ endeffector devices _ camera devices
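+# Registered robot_type strings: "z1_realsense", "z1_dual_dex1_realsense",
+# "z1_dual_dex1_opencv", "g1_dex1".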
+@dataclass
+class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+
+@dataclass
+class UnitreeRobotConfig(RobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
+    arm: dict[str, ArmConfig] = field(default_factory=lambda: {})
+    endeffector: dict[str, EndEffectorConfig] = field(default_factory=lambda: {})
+
+
+# =============================== Single-arm:z1, Camera:Realsense ========================================
+@RobotConfig.register_subclass("z1_realsense")
+@dataclass
+class Z1_Realsense_RobotConfig(UnitreeRobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=z1_intelrealsense_camera_default_factory)
+    arm: dict[str, ArmConfig] = field(default_factory=z1_arm_default_factory)
+
+
+# =============================== Dual-arm:z1, Endeffector:dex1, Camera:Realsense ========================================
+@RobotConfig.register_subclass("z1_dual_dex1_realsense")
+@dataclass
+class Z1dual_Dex1_Realsense_RobotConfig(UnitreeRobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=z1_dual_intelrealsense_camera_default_factory)
+    arm: dict[str, ArmConfig] = field(default_factory=z1_dual_arm_single_config_factory)
+    endeffector: dict[str, EndEffectorConfig] = field(default_factory=dex1_default_factory)
+
+
+# =============================== Dual-arm:z1, Endeffector:dex1, Camera:OpenCV ========================================
+@RobotConfig.register_subclass("z1_dual_dex1_opencv")
+@dataclass
+class Z1dual_Dex1_Opencv_RobotConfig(UnitreeRobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=usb_camera_default_factory)
+    arm: dict[str, ArmConfig] = field(default_factory=z1_dual_arm_single_config_factory)
+    endeffector: dict[str, EndEffectorConfig] = field(default_factory=dex1_default_factory)
+
+
+# =============================== Arm:g1, Endeffector:dex1, Camera:image client ========================================
+@RobotConfig.register_subclass("g1_dex1")
+@dataclass
+class G1_Dex1_ImageClient_RobotConfig(UnitreeRobotConfig):
+    cameras: dict[str, CameraConfig] = field(default_factory=g1_image_client_default_factory)
+    arm: dict[str, ArmConfig] = field(default_factory=g1_dual_arm_default_factory)
+    endeffector: dict[str, EndEffectorConfig] = field(default_factory=dex1_default_factory)
diff --git a/unitree_deploy/unitree_deploy/robot/robot_utils.py b/unitree_deploy/unitree_deploy/robot/robot_utils.py
new file mode 100644
index 0000000..746a5e5
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot/robot_utils.py
@@ -0,0 +1,47 @@
+from typing import Protocol
+
+from unitree_deploy.robot.robot_configs import (
+    G1_Dex1_ImageClient_RobotConfig,
+    RobotConfig,
+    Z1_Realsense_RobotConfig,
+    Z1dual_Dex1_Opencv_RobotConfig,
+    Z1dual_Dex1_Realsense_RobotConfig,
+)
+
+
+def get_arm_id(name, arm_type):
+    return f"{name}_{arm_type}"
+
+
+class Robot(Protocol):
+    robot_type: str
+    features: dict
+
+    def connect(self): ...
+    def capture_observation(self): ...
+    def send_action(self, action): ...
+    def disconnect(self): ...
+
+
+def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
+    if robot_type == "z1_realsense":
+        return Z1_Realsense_RobotConfig(**kwargs)
+    elif robot_type == "z1_dual_dex1_realsense":
+        return Z1dual_Dex1_Realsense_RobotConfig(**kwargs)
+    elif robot_type == "z1_dual_dex1_opencv":
+        return Z1dual_Dex1_Opencv_RobotConfig(**kwargs)
+    elif robot_type == "g1_dex1":
+        return G1_Dex1_ImageClient_RobotConfig(**kwargs)
+    else:
+        raise ValueError(f"Robot type '{robot_type}' is not available.")
+
+
+def make_robot_from_config(config: RobotConfig):
+    from unitree_deploy.robot.robot import UnitreeRobot
+
+    return UnitreeRobot(config)
+
+
+def make_robot(robot_type: str, **kwargs) -> Robot:
+    config = make_robot_config(robot_type, **kwargs)
+    return make_robot_from_config(config)
diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/arm_indexs.py b/unitree_deploy/unitree_deploy/robot_devices/arm/arm_indexs.py
new file mode 100644
index 0000000..adcfdbd
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/arm/arm_indexs.py
@@ -0,0 +1,119 @@
+# noqa: N815
+from enum import IntEnum
+
+
+# ==================g1========================
+class G1_29_JointArmIndex(IntEnum):
+    # Left arm
+    kLeftShoulderPitch = 15
+    kLeftShoulderRoll = 16
+    kLeftShoulderYaw = 17
+    kLeftElbow = 18
+    kLeftWristRoll = 19
+    kLeftWristPitch = 20
+    kLeftWristyaw = 21
+
+    # Right arm
+    kRightShoulderPitch = 22
+    kRightShoulderRoll = 23
+    kRightShoulderYaw = 24
+    kRightElbow = 25
+    kRightWristRoll = 26
+    kRightWristPitch = 27
+    kRightWristYaw = 28
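+
+
+# G1_29_JointArmIndex above is the 14-joint arm subset (indices 15-28) of the
+# full-body index below, which also reserves six unused slots (29-34).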
+class G1_29_JointIndex(IntEnum):
+    # Left leg
+    kLeftHipPitch = 0
+    kLeftHipRoll = 1
+    kLeftHipYaw = 2
+    kLeftKnee = 3
+    kLeftAnklePitch = 4
+    kLeftAnkleRoll = 5
+
+    # Right leg
+    kRightHipPitch = 6
+    kRightHipRoll = 7
+    kRightHipYaw = 8
+    kRightKnee = 9
+    kRightAnklePitch = 10
+    kRightAnkleRoll = 11
+
+    kWaistYaw = 12
+    kWaistRoll = 13
+    kWaistPitch = 14
+
+    # Left arm
+    kLeftShoulderPitch = 15
+    kLeftShoulderRoll = 16
+    kLeftShoulderYaw = 17
+    kLeftElbow = 18
+    kLeftWristRoll = 19
+    kLeftWristPitch = 20
+    kLeftWristyaw = 21
+
+    # Right arm
+    kRightShoulderPitch = 22
+    kRightShoulderRoll = 23
+    kRightShoulderYaw = 24
+    kRightElbow = 25
+    kRightWristRoll = 26
+    kRightWristPitch = 27
+    kRightWristYaw = 28
+
+    # not used
+    kNotUsedJoint0 = 29
+    kNotUsedJoint1 = 30
+    kNotUsedJoint2 = 31
+    kNotUsedJoint3 = 32
+    kNotUsedJoint4 = 33
+    kNotUsedJoint5 = 34
+
+
+# ==========================================
+
+
+# ==================z1========================
+class Z1ArmJointIndex(IntEnum):
+    WAIST = 0
+    SHOULDER = 1
+    ELBOW = 2
+    FOREARM_ROLL = 3
+    WRIST_ANGLE = 4
+    WRIST_ROTATE = 5
+
+
+class Z1_12_JointArmIndex(IntEnum):
+    # Left arm
+    kLeftWaist = 0
+    kLeftShoulder = 1
+    kLeftElbow = 2
+    kLeftForearmRoll = 3
+    kLeftWristAngle = 4
+    kLeftWristRotate = 5
+
+    # Right arm
+    kRightWaist = 6
+    kRightShoulder = 7
+    kRightElbow = 8
+    kRightForearmRoll = 9
+    kRightWristAngle = 10
+    kRightWristRotate = 11
+
+
+class Z1GripperArmJointIndex(IntEnum):
+    WAIST = 0
+    SHOULDER = 1
+    ELBOW = 2
+    FOREARM_ROLL = 3
+    WRIST_ANGLE = 4
+    WRIST_ROTATE = 5
+    GRIPPER = 6
+
+
+class Gripper_Sigle_JointIndex(IntEnum):
+    kGripper = 0
+
+
+# ==========================================
diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/configs.py b/unitree_deploy/unitree_deploy/robot_devices/arm/configs.py
new file mode 100644
index 0000000..7cc88dd
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/arm/configs.py
@@ -0,0 +1,82 @@
+import abc
+from dataclasses import dataclass
+
+import draccus
+import numpy as np
+
+
+@dataclass
+class ArmConfig(draccus.ChoiceRegistry, abc.ABC):
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+
+@ArmConfig.register_subclass("z1")
+@dataclass
+class Z1ArmConfig(ArmConfig):
+    motors: dict[str, tuple[int, str]]
+
+    init_pose: list | None = None
+    unit_test: bool = False
+    control_dt: float = 1 / 500.0
+
+    robot_kp: np.ndarray = np.array([4, 6, 6, 6, 6, 6])
+    robot_kd: np.ndarray = np.array([350, 300, 300, 200, 200, 200])
+    max_pos_speed: float = 180 * (np.pi / 180) * 2
+    log_level: str | int = "ERROR"
+
+    def __post_init__(self):
+        if self.control_dt < 0.002:
+            raise ValueError(f"`control_dt` must be >= 1/500 s (got {self.control_dt})")
+
+
+@ArmConfig.register_subclass("z1_dual")
+@dataclass
+class Z1DualArmConfig(ArmConfig):
+    left_robot_ip: str
+    left_robot_port1: int
+    left_robot_port2: int
+    right_robot_ip: str
+    right_robot_port1: int
+    right_robot_port2: int
+    motors: dict[str, tuple[int, str]]
+
+    robot_kp: np.ndarray = np.array([4, 6, 6, 6, 6, 6])
+    robot_kd: np.ndarray = np.array([350, 300, 300, 200, 200, 200])
+    mock: bool = False
+    unit_test: bool = False
+    init_pose_left: list | None = None
+    init_pose_right: list | None = None
+    max_pos_speed: float = 180 * (np.pi / 180) * 2
+    control_dt: float = 1 / 500.0
+
+    def __post_init__(self):
+        if self.control_dt < 0.002:
+            raise ValueError(f"`control_dt` must be >= 1/500 s (got {self.control_dt})")
+
+
+@ArmConfig.register_subclass("g1")
+@dataclass
+class G1ArmConfig(ArmConfig):
+    motors: dict[str, tuple[int, str]]
+    mock: bool = False
+    unit_test: bool = False
+    init_pose: np.ndarray | list = np.zeros(14)
+
+    control_dt: float = 1 / 500.0
+    max_pos_speed: float = 180 * (np.pi / 180) * 2
+
+    topic_low_command: str = "rt/lowcmd"
+    topic_low_state: str = "rt/lowstate"
+
+    kp_high: float = 300.0
+    kd_high: float = 3.0
+    kp_low: float = 80.0  # 140.0
+    kd_low: float = 3.0  # 3.0
+    kp_wrist: float = 40.0
+    kd_wrist: float = 1.5
+
+    def __post_init__(self):
+        if self.control_dt < 0.002:
+            raise ValueError(f"`control_dt` must be >= 1/500 s (got {self.control_dt})")
diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm.py b/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm.py
new file mode 100644
index 0000000..44d0f3b
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm.py
@@ -0,0 +1,385 @@
+import threading
+import time
+from typing import Callable
+
+import numpy as np
+from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber  # dds
+from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
+from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_, LowState_  # idl
+from unitree_sdk2py.utils.crc import CRC
+
+from unitree_deploy.robot_devices.arm.arm_indexs import G1_29_JointArmIndex, G1_29_JointIndex
+from unitree_deploy.robot_devices.arm.configs import G1ArmConfig
+from unitree_deploy.robot_devices.arm.g1_arm_ik import G1_29_ArmIK
+from unitree_deploy.robot_devices.robots_devices_utils import (
+    DataBuffer,
+    MotorState,
+    Robot_Num_Motors,
+    RobotDeviceAlreadyConnectedError,
+)
+from unitree_deploy.utils.joint_trajcetory_inter import JointTrajectoryInterpolator
+from unitree_deploy.utils.rich_logger import log_error, log_info, log_success, log_warning
+from unitree_deploy.utils.run_simulation import MujicoSimulation, get_mujoco_sim_config
+
+
+class G1_29_LowState:
+    def __init__(self):
+        self.motor_state = [MotorState() for _ in range(Robot_Num_Motors.G1_29_Num_Motors)]
+
+
+class G1_29_ArmController:
+    def __init__(self, config: G1ArmConfig):
+        self.motors = config.motors
+        self.mock = config.mock
+        self.unit_test = config.unit_test
+        self.init_pose = config.init_pose
+        self.control_dt = config.control_dt
+
+        self.max_pos_speed = config.max_pos_speed
+
+        self.topic_low_command = config.topic_low_command
+        self.topic_low_state = config.topic_low_state
+
+        self.kp_high = config.kp_high
+        self.kd_high = config.kd_high
+        self.kp_low = config.kp_low
+        self.kd_low = config.kd_low
+        self.kp_wrist = config.kp_wrist
+        self.kd_wrist = config.kd_wrist
+
+        self.all_motor_q = None
+        self.q_target = np.zeros(14)
+        self.dq_target = np.zeros(14)
+        self.tauff_target = np.zeros(14)
+        self.time_target = time.monotonic()
+        self.arm_cmd = "schedule_waypoint"
+
+        self.lowstate_buffer = DataBuffer()
+        self.g1_arm_ik = G1_29_ArmIK(unit_test=self.unit_test, visualization=False)
+
+        self.stop_event = threading.Event()
+        self.ctrl_lock = threading.Lock()
+
+        self.is_connected = False
+
+    @property
+    def motor_names(self) -> list[str]:
+        return list(self.motors.keys())
+
+    @property
+    def motor_models(self) -> list[str]:
+        return [model for _, model in self.motors.values()]
+
+    @property
+    def motor_indices(self) -> list[int]:
+        return [idx for idx, _ in self.motors.values()]
+
+    def _start_daemon_thread(self, target_fn: Callable[[], None], name: str | None = None) -> threading.Thread:
+        thread = threading.Thread(target=target_fn, name=name)
+        thread.daemon = True
+        thread.start()
+        return thread
+
+    def connect(self):
+        try:
+            if self.is_connected:
+                raise RobotDeviceAlreadyConnectedError(
+                    "G1_Arm is already connected. Do not call `robot.connect()` twice."
+                )
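+            # In mock mode the controller drives a MuJoCo simulation (see
+            # utils.run_simulation) instead of the DDS channels used on hardware.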
+                )
+            if self.mock:
+                config = get_mujoco_sim_config(robot_type="g1")
+                self.g1 = MujicoSimulation(config)
+                time.sleep(1)
+            else:
+                # initialize lowcmd publisher and lowstate subscriber
+                ChannelFactoryInitialize(0)
+                self.lowcmd_publisher = ChannelPublisher(self.topic_low_command, LowCmd_)
+                self.lowcmd_publisher.Init()
+                self.lowstate_subscriber = ChannelSubscriber(self.topic_low_state, LowState_)
+                self.lowstate_subscriber.Init()
+
+            # initialize subscribe thread
+            self.subscribe_thread = self._start_daemon_thread(
+                self._subscribe_motor_state, name="g1._subscribe_motor_state"
+            )
+
+            while not self.lowstate_buffer.get_data():
+                time.sleep(0.01)
+                log_warning("[G1_29_ArmController] Waiting to subscribe to DDS lowstate...")
+
+            if not self.mock:
+                # initialize hg's lowcmd msg
+                self.crc = CRC()
+                self.msg = unitree_hg_msg_dds__LowCmd_()
+                self.msg.mode_pr = 0
+                self.msg.mode_machine = self._read_mode_machine()
+
+                self.all_motor_q = self._read_current_motor_q()
+                log_info(f"Current all body motor state q:\n{self.all_motor_q} \n")
+                log_info(f"Current two arms motor state q:\n{self.read_current_arm_q()}\n")
+                log_info("Lock all joints except two arms...\n")
+
+                arm_indices = {member.value for member in G1_29_JointArmIndex}
+                for id in G1_29_JointIndex:
+                    self.msg.motor_cmd[id].mode = 1
+                    if id.value in arm_indices:
+                        if self._is_wrist_motor(id):
+                            self.msg.motor_cmd[id].kp = self.kp_wrist
+                            self.msg.motor_cmd[id].kd = self.kd_wrist
+                        else:
+                            self.msg.motor_cmd[id].kp = self.kp_low
+                            self.msg.motor_cmd[id].kd = self.kd_low
+                    else:
+                        if self._is_weak_motor(id):
+                            self.msg.motor_cmd[id].kp = self.kp_low
+                            self.msg.motor_cmd[id].kd = self.kd_low
+                        else:
+                            self.msg.motor_cmd[id].kp = self.kp_high
+                            self.msg.motor_cmd[id].kd = self.kd_high
+                    self.msg.motor_cmd[id].q = self.all_motor_q[id]
+                log_info("Lock OK!\n")
+
+            # initialize publish thread
+            self.publish_thread = self._start_daemon_thread(self._ctrl_motor_state, name="g1._ctrl_motor_state")
+            self.is_connected = True
+
+        except Exception as e:
+            self.disconnect()
+            log_error(f"❌ Error in G1_29_ArmController.connect: {e}")
+
+    def _subscribe_motor_state(self):
+        try:
+            while not self.stop_event.is_set():
+                lowstate = G1_29_LowState()
+                if self.mock:
+                    positions = self.g1.get_current_positions()  # read once per tick
+                    if positions is not None and len(positions) != 0:
+                        for motor_id in range(Robot_Num_Motors.G1_29_Num_Motors):
+                            lowstate.motor_state[motor_id].q = positions[motor_id]
+                            lowstate.motor_state[motor_id].dq = 0.0
+                    else:
+                        log_warning("[G1_29_ArmController] get_current_positions() failed: queue is empty.")
+                else:
+                    msg = self.lowstate_subscriber.Read()
+                    if msg is not None:
+                        for id in range(Robot_Num_Motors.G1_29_Num_Motors):
+                            lowstate.motor_state[id].q = msg.motor_state[id].q
+                            lowstate.motor_state[id].dq = msg.motor_state[id].dq
+                self.lowstate_buffer.set_data(lowstate)
+                time.sleep(self.control_dt)
+        except Exception as e:
+            self.disconnect()
+            log_error(f"❌ Error in G1_29_ArmController._subscribe_motor_state: {e}")
+
+    def _update_g1_arm(
+        self,
+        arm_q_target: np.ndarray,
+        arm_dq_target: np.ndarray | None = None,
+        arm_tauff_target: np.ndarray | None = None,
+    ):
+        if self.mock:
+            self.g1.set_positions(arm_q_target)
+        else:
+            for idx, id in enumerate(G1_29_JointArmIndex):
+                self.msg.motor_cmd[id].q = arm_q_target[idx]
+                self.msg.motor_cmd[id].dq = arm_dq_target[idx]
+                self.msg.motor_cmd[id].tau = arm_tauff_target[idx]
+
+            self.msg.crc = self.crc.Crc(self.msg)
+            self.lowcmd_publisher.Write(self.msg)
+
+    def _drive_to_waypoint(self, target_pose: np.ndarray, t_insert_time: float):
+        curr_time = time.monotonic() + self.control_dt
+        t_insert = curr_time + t_insert_time
+        self.pose_interp = self.pose_interp.drive_to_waypoint(
+            pose=target_pose,
+            time=t_insert,
+            curr_time=curr_time,
+            max_pos_speed=self.max_pos_speed,
+        )
+
+        while time.monotonic() < t_insert:
+            start_time = time.perf_counter()
+
+            clipped_arm_q_target = self.pose_interp(time.monotonic())
+            self._update_g1_arm(clipped_arm_q_target, self.dq_target, self.tauff_target)
+
+            time.sleep(max(0, (self.control_dt - (time.perf_counter() - start_time))))
+
+    def _schedule_waypoint(
+        self,
+        arm_q_target: np.ndarray,
+        arm_time_target: float,
+        t_now: float,
+        start_time: float,
+        last_waypoint_time: float,
+        arm_tauff_target: np.ndarray | None = None,
+    ) -> float:
+        # `arm_time_target` is a perf_counter timestamp; shift it onto the
+        # monotonic clock used by the interpolator.
+        target_time = time.monotonic() - time.perf_counter() + arm_time_target
+        curr_time = t_now + self.control_dt
+        target_time = max(target_time, curr_time + self.control_dt)
+
+        self.pose_interp = self.pose_interp.schedule_waypoint(
+            pose=arm_q_target,
+            time=target_time,
+            max_pos_speed=self.max_pos_speed,
+            curr_time=curr_time,
+            last_waypoint_time=last_waypoint_time,
+        )
+        last_waypoint_time = target_time
+
+        clipped_arm_q_target = self.pose_interp(t_now)
+        self._update_g1_arm(clipped_arm_q_target, self.dq_target, arm_tauff_target)
+
+        time.sleep(max(0, (self.control_dt - (time.perf_counter() - start_time))))
+        return last_waypoint_time
+
+    def _ctrl_motor_state(self):
+        # wait for DDS init to finish
+        time.sleep(2)
+
+        self.pose_interp = JointTrajectoryInterpolator(
+            times=[time.monotonic()], joint_positions=[self.read_current_arm_q()]
+        )
+        self.go_start()
+
+        arm_q_target = self.read_current_arm_q()
+        arm_tauff_target = self.tauff_target
+        arm_time_target = time.monotonic()
+        arm_cmd = "schedule_waypoint"
+
+        last_waypoint_time = time.monotonic()
+        while not self.stop_event.is_set():
+            start_time = time.perf_counter()
+            t_now = time.monotonic()
+
+            with self.ctrl_lock:
+                arm_q_target = self.q_target
+                arm_tauff_target = self.tauff_target
+                arm_time_target = self.time_target
+                arm_cmd = self.arm_cmd
+
+            if arm_cmd == "drive_to_waypoint":
+                self._drive_to_waypoint(target_pose=arm_q_target, t_insert_time=0.8)
+
+            elif arm_cmd == "schedule_waypoint":
+                last_waypoint_time = self._schedule_waypoint(
+                    arm_q_target=arm_q_target,
+                    arm_time_target=arm_time_target,
+                    t_now=t_now,
+                    start_time=start_time,
+                    last_waypoint_time=last_waypoint_time,
+                    arm_tauff_target=arm_tauff_target,
+                )
+
+    def _read_mode_machine(self):
+        """Return current dds mode machine."""
+        return self.lowstate_subscriber.Read().mode_machine
+
+    def _read_current_motor_q(self) -> np.ndarray | None:
+        """Return current state q of all body motors."""
+        return np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in G1_29_JointIndex])
+
+    def read_current_arm_q(self) -> np.ndarray | None:
+        """Return current state q of the left and right arm motors."""
+        return np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in G1_29_JointArmIndex])
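A hypothetical caller-side sketch (not in the diff) of how `write_arm` feeds the control thread above. The mock flag, motor map, rate, and 100 ms horizon are all illustrative; `time_target` is a `time.perf_counter()` timestamp because `_schedule_waypoint` converts it onto the monotonic clock:

```python
# Illustration only: stream scheduled waypoints at ~50 Hz to the controller.
import time

from unitree_deploy.robot_devices.arm.configs import G1ArmConfig
from unitree_deploy.robot_devices.arm.g1_arm import G1_29_ArmController

config = G1ArmConfig(motors={"kLeftShoulderPitch": (15, "g1")}, mock=True)  # placeholders
controller = G1_29_ArmController(config)
controller.connect()

q = controller.read_current_arm_q()
for _ in range(100):
    q = q + 0.001  # tiny joint-space step, purely illustrative
    controller.write_arm(
        q_target=q,
        time_target=time.perf_counter() + 0.1,  # execute ~100 ms from now
        cmd_target="schedule_waypoint",
    )
    time.sleep(0.02)

controller.disconnect()
```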
+
+    def read_current_arm_dq(self) -> np.ndarray | None:
+        """Return current state dq of the left and right arm motors."""
+        return np.array([self.lowstate_buffer.get_data().motor_state[id].dq for id in G1_29_JointArmIndex])
+
+    def write_arm(
+        self,
+        q_target: list[float] | np.ndarray,
+        tauff_target: list[float] | np.ndarray | None = None,
+        time_target: float | None = None,
+        cmd_target: str | None = None,
+    ):
+        """Set control target values q & tau of the left and right arm motors."""
+        with self.ctrl_lock:
+            self.q_target = q_target
+            self.tauff_target = tauff_target if tauff_target is not None else self.arm_tau(self.q_target)
+            self.time_target = time_target
+            self.arm_cmd = cmd_target
+
+    def arm_ik(
+        self, l_ee_target: list[float] | np.ndarray, r_ee_target: list[float] | np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray] | None:
+        return self.g1_arm_ik.solve_ik(l_ee_target, r_ee_target, self.read_current_arm_q(), self.read_current_arm_dq())
+
+    def arm_tau(
+        self, current_arm_q: np.ndarray | None = None, current_arm_dq: np.ndarray | None = None
+    ) -> np.ndarray | None:
+        return self.g1_arm_ik.solve_tau(current_arm_q, current_arm_dq)
+
+    def arm_fk(self, q: np.ndarray | None = None) -> np.ndarray | None:
+        # Forward kinematics is not implemented for the G1 controller yet.
+        pass
+
+    def go_start(self):
+        self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=2.0)
+        log_success("[G1_29_ArmController] Go Start OK!\n")
+
+    def go_home(self):
+        if self.mock:
+            self.stop_event.set()
+            # self.subscribe_thread.join()
+            # self.publish_thread.join()
+
+            time.sleep(1)
+            # self.g1.stop()
+
+        else:
+            self.stop_event.set()
+            self.publish_thread.join()
+
+            self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=2.0)
+        log_success("[G1_29_ArmController] Go Home OK!\n")
+
+    def disconnect(self):
+        self.is_connected = False
+        self.go_home()
+
+    def _is_weak_motor(self, motor_index):
+        weak_motors = [
+            G1_29_JointIndex.kLeftAnklePitch.value,
+            G1_29_JointIndex.kRightAnklePitch.value,
+            # Left arm
+            G1_29_JointIndex.kLeftShoulderPitch.value,
+            G1_29_JointIndex.kLeftShoulderRoll.value,
+            G1_29_JointIndex.kLeftShoulderYaw.value,
+            G1_29_JointIndex.kLeftElbow.value,
+            # Right arm
+            G1_29_JointIndex.kRightShoulderPitch.value,
+            G1_29_JointIndex.kRightShoulderRoll.value,
+            G1_29_JointIndex.kRightShoulderYaw.value,
+            G1_29_JointIndex.kRightElbow.value,
+        ]
+        return motor_index.value in weak_motors
+
+    def _is_wrist_motor(self, motor_index):
+        wrist_motors = [
+            G1_29_JointIndex.kLeftWristRoll.value,
+            G1_29_JointIndex.kLeftWristPitch.value,
+            G1_29_JointIndex.kLeftWristYaw.value,
+            G1_29_JointIndex.kRightWristRoll.value,
+            G1_29_JointIndex.kRightWristPitch.value,
+            G1_29_JointIndex.kRightWristYaw.value,
+        ]
+        return motor_index.value in wrist_motors
diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm_ik.py b/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm_ik.py
new file mode 100644
index 0000000..9781401
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/arm/g1_arm_ik.py
@@ -0,0 +1,280 @@
+import casadi
+import meshcat.geometry as mg
+import numpy as np
+import pinocchio as pin
+from pinocchio import casadi as cpin
+from pinocchio.visualize import MeshcatVisualizer
+
+from unitree_deploy.utils.weighted_moving_filter import WeightedMovingFilter
+
+
+class G1_29_ArmIK:
+    def __init__(self, unit_test=False, visualization=False):
+        np.set_printoptions(precision=5, suppress=True, linewidth=200)
+
+        self.unit_test = unit_test
self.visualization = visualization + + if not self.unit_test: + self.robot = pin.RobotWrapper.BuildFromURDF( + "unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.urdf", + "unitree_deploy/robot_devices/assets/g1/", + ) + else: + self.robot = pin.RobotWrapper.BuildFromURDF( + "unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.urdf", + "unitree_deploy/robot_devices/assets/g1/", + ) # for test + + self.mixed_jointsToLockIDs = [ + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_hand_thumb_0_joint", + "left_hand_thumb_1_joint", + "left_hand_thumb_2_joint", + "left_hand_middle_0_joint", + "left_hand_middle_1_joint", + "left_hand_index_0_joint", + "left_hand_index_1_joint", + "right_hand_thumb_0_joint", + "right_hand_thumb_1_joint", + "right_hand_thumb_2_joint", + "right_hand_index_0_joint", + "right_hand_index_1_joint", + "right_hand_middle_0_joint", + "right_hand_middle_1_joint", + ] + + self.reduced_robot = self.robot.buildReducedRobot( + list_of_joints_to_lock=self.mixed_jointsToLockIDs, + reference_configuration=np.array([0.0] * self.robot.model.nq), + ) + + self.reduced_robot.model.addFrame( + pin.Frame( + "L_ee", + self.reduced_robot.model.getJointId("left_wrist_yaw_joint"), + pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + pin.FrameType.OP_FRAME, + ) + ) + + self.reduced_robot.model.addFrame( + pin.Frame( + "R_ee", + self.reduced_robot.model.getJointId("right_wrist_yaw_joint"), + pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + pin.FrameType.OP_FRAME, + ) + ) + + # for i in range(self.reduced_robot.model.nframes): + # frame = self.reduced_robot.model.frames[i] + # frame_id = self.reduced_robot.model.getFrameId(frame.name) + + # Creating Casadi models and data for symbolic computing + self.cmodel = cpin.Model(self.reduced_robot.model) + self.cdata = self.cmodel.createData() + + # Creating symbolic variables + self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1) + self.cTf_l = casadi.SX.sym("tf_l", 4, 4) + self.cTf_r = casadi.SX.sym("tf_r", 4, 4) + cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + + # Get the hand joint ID and define the error function + self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee") + self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee") + + self.translational_error = casadi.Function( + "translational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3], + self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3], + ) + ], + ) + self.rotational_error = casadi.Function( + "rotational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T), + cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T), + ) + ], + ) + + # Defining the optimization problem + self.opti = casadi.Opti() + self.var_q = self.opti.variable(self.reduced_robot.model.nq) + self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth + self.param_tf_l = self.opti.parameter(4, 4) + self.param_tf_r = self.opti.parameter(4, 4) + self.translational_cost = casadi.sumsqr(self.translational_error(self.var_q, self.param_tf_l, 
self.param_tf_r)) + self.rotation_cost = casadi.sumsqr(self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r)) + self.regularization_cost = casadi.sumsqr(self.var_q) + self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last) + + # Setting optimization constraints and goals + self.opti.subject_to( + self.opti.bounded( + self.reduced_robot.model.lowerPositionLimit, + self.var_q, + self.reduced_robot.model.upperPositionLimit, + ) + ) + self.opti.minimize( + 50 * self.translational_cost + self.rotation_cost + 0.02 * self.regularization_cost + 0.1 * self.smooth_cost + ) + + opts = { + "ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6}, + "print_time": False, # print or not + "calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F + } + self.opti.solver("ipopt", opts) + + self.init_data = np.zeros(self.reduced_robot.model.nq) + self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14) + self.vis = None + + if self.visualization: + # Initialize the Meshcat visualizer for visualization + self.vis = MeshcatVisualizer( + self.reduced_robot.model, self.reduced_robot.collision_model, self.reduced_robot.visual_model + ) + self.vis.initViewer(open=True) + self.vis.loadViewerModel("pinocchio") + self.vis.displayFrames(True, frame_ids=[101, 102], axis_length=0.15, axis_width=5) + self.vis.display(pin.neutral(self.reduced_robot.model)) + + # Enable the display of end effector target frames with short axis lengths and greater width. + frame_viz_names = ["L_ee_target", "R_ee_target"] + frame_axis_positions = ( + np.array([[0, 0, 0], [1, 0, 0], [0, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 1]]).astype(np.float32).T + ) + frame_axis_colors = ( + np.array([[1, 0, 0], [1, 0.6, 0], [0, 1, 0], [0.6, 1, 0], [0, 0, 1], [0, 0.6, 1]]).astype(np.float32).T + ) + axis_length = 0.1 + axis_width = 10 + for frame_viz_name in frame_viz_names: + self.vis.viewer[frame_viz_name].set_object( + mg.LineSegments( + mg.PointsGeometry( + position=axis_length * frame_axis_positions, + color=frame_axis_colors, + ), + mg.LineBasicMaterial( + linewidth=axis_width, + vertexColors=True, + ), + ) + ) + + # If the robot arm is not the same size as your arm :) + def scale_arms(self, human_left_pose, human_right_pose, human_arm_length=0.60, robot_arm_length=0.75): + scale_factor = robot_arm_length / human_arm_length + robot_left_pose = human_left_pose.copy() + robot_right_pose = human_right_pose.copy() + robot_left_pose[:3, 3] *= scale_factor + robot_right_pose[:3, 3] *= scale_factor + return robot_left_pose, robot_right_pose + + def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + if current_lr_arm_motor_q is not None: + self.init_data = current_lr_arm_motor_q + self.opti.set_initial(self.var_q, self.init_data) + + # left_wrist, right_wrist = self.scale_arms(left_wrist, right_wrist) + if self.visualization: + self.vis.viewer["L_ee_target"].set_transform(left_wrist) # for visualization + self.vis.viewer["R_ee_target"].set_transform(right_wrist) # for visualization + + self.opti.set_value(self.param_tf_l, left_wrist) + self.opti.set_value(self.param_tf_r, right_wrist) + self.opti.set_value(self.var_q_last, self.init_data) # for smooth + + try: + self.opti.solve() + # sol = self.opti.solve_limited() + + sol_q = self.opti.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + v = current_lr_arm_motor_dq * 0.0 if current_lr_arm_motor_dq is not 
None else (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + sol_tauff = pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + if self.visualization: + self.vis.display(sol_q) # for visualization + + return sol_q, sol_tauff + + except Exception as e: + print(f"ERROR in convergence, plotting debug info.{e}") + + sol_q = self.opti.debug.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + v = current_lr_arm_motor_dq * 0.0 if current_lr_arm_motor_dq is not None else (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + sol_tauff = pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + print( + f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}" + ) + if self.visualization: + self.vis.display(sol_q) # for visualization + + # return sol_q, sol_tauff + return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + + def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + try: + sol_tauff = pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + current_lr_arm_motor_q, + np.zeros(14), + np.zeros(self.reduced_robot.model.nv), + ) + return sol_tauff + + except Exception as e: + print(f"ERROR in convergence, plotting debug info.{e}") + return np.zeros(self.reduced_robot.model.nv) \ No newline at end of file diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/utils.py b/unitree_deploy/unitree_deploy/robot_devices/arm/utils.py new file mode 100644 index 0000000..3a489d4 --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/arm/utils.py @@ -0,0 +1,63 @@ +from typing import Protocol + +from unitree_deploy.robot_devices.arm.configs import ArmConfig, G1ArmConfig, Z1ArmConfig, Z1DualArmConfig + + +class Arm(Protocol): + def connect(self): ... + def disconnect(self): ... + def motor_names(self): ... + + def read_current_motor_q(self): ... + def read_current_arm_q(self): ... + def read_current_arm_dq(self): ... + def write_arm(self): ... + + def arm_ik(self): ... + def arm_fk(self): ... + def go_start(self): ... + def go_home(self): ... 
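For context, a hedged sketch (not part of the diff) of consuming the `Arm` protocol through the factory helpers defined just below; the motor map and start pose are placeholders:

```python
# Illustration only: build one Z1 controller via make_arm_motors_bus,
# read its joint state, and shut down. All values are placeholders.
from unitree_deploy.robot_devices.arm.utils import make_arm_motors_bus

arm = make_arm_motors_bus(
    "z1",
    motors={f"joint{i}": (i, "z1") for i in range(6)},  # placeholder mapping
    init_pose=[0.0, 0.5, -0.5, 0.0, 0.0, 0.0, -1.0],    # 6 joints + gripper
)
arm.connect()
print(arm.read_current_arm_q())  # 7 values, gripper last
arm.disconnect()
```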
+
+
+def make_arm_motors_buses_from_configs(armconfig: dict[str, ArmConfig]) -> dict[str, Arm]:
+    arm_motors_buses = {}
+
+    for key, cfg in armconfig.items():
+        if cfg.type == "z1":
+            from unitree_deploy.robot_devices.arm.z1_arm import Z1ArmController
+
+            arm_motors_buses[key] = Z1ArmController(cfg)
+        elif cfg.type == "g1":
+            from unitree_deploy.robot_devices.arm.g1_arm import G1_29_ArmController
+
+            arm_motors_buses[key] = G1_29_ArmController(cfg)
+        elif cfg.type == "z1_dual":
+            from unitree_deploy.robot_devices.arm.z1_dual_arm import Z1_12_ArmController
+
+            arm_motors_buses[key] = Z1_12_ArmController(cfg)
+        else:
+            raise ValueError(f"The motor type '{cfg.type}' is not valid.")
+
+    return arm_motors_buses
+
+
+def make_arm_motors_bus(arm_type: str, **kwargs) -> Arm:
+    if arm_type == "z1":
+        from unitree_deploy.robot_devices.arm.z1_arm import Z1ArmController
+
+        config = Z1ArmConfig(**kwargs)
+        return Z1ArmController(config)
+
+    elif arm_type == "z1_dual":
+        from unitree_deploy.robot_devices.arm.z1_dual_arm import Z1_12_ArmController
+
+        config = Z1DualArmConfig(**kwargs)
+        return Z1_12_ArmController(config)
+
+    elif arm_type == "g1":
+        from unitree_deploy.robot_devices.arm.g1_arm import G1_29_ArmController
+
+        config = G1ArmConfig(**kwargs)
+        return G1_29_ArmController(config)
+    else:
+        raise ValueError(f"The motor type '{arm_type}' is not valid.")
diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm.py b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm.py
new file mode 100644
index 0000000..909c8f8
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm.py
@@ -0,0 +1,294 @@
+import os
+import sys
+import threading
+import time
+from typing import Callable
+
+import numpy as np
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import unitree_arm_interface as unitree_z1  # type: ignore
+
+from unitree_deploy.robot_devices.arm.arm_indexs import Z1GripperArmJointIndex
+from unitree_deploy.robot_devices.arm.configs import Z1ArmConfig
+from unitree_deploy.robot_devices.arm.z1_arm_ik import Z1_Arm_IK
+from unitree_deploy.robot_devices.robots_devices_utils import (
+    DataBuffer,
+    MotorState,
+    Robot_Num_Motors,
+    RobotDeviceAlreadyConnectedError,
+)
+from unitree_deploy.utils.joint_trajcetory_inter import JointTrajectoryInterpolator
+from unitree_deploy.utils.rich_logger import RichLogger
+
+
+class Z1LowState:
+    def __init__(self) -> None:
+        self.motor_state: list[MotorState] = [MotorState() for _ in range(Robot_Num_Motors.Z1_7_Num_Motors)]
+
+
+class Z1ArmController:
+    def __init__(self, config: Z1ArmConfig):
+        self.motors = config.motors
+
+        self.init_pose = config.init_pose
+        self.unit_test = config.unit_test
+        self.control_dt = config.control_dt
+
+        self.robot_kp = config.robot_kp
+        self.robot_kd = config.robot_kd
+        self.max_pos_speed = config.max_pos_speed
+        self.log_level = config.log_level
+
+        self.q_target = self.init_pose
+        self.dq_target = np.zeros(len(Z1GripperArmJointIndex) - 1, dtype=np.float16)
+        self.ddq_target = np.zeros(len(Z1GripperArmJointIndex) - 1, dtype=np.float16)
+        self.tauff_target = np.zeros(len(Z1GripperArmJointIndex) - 1, dtype=np.float16)
+        self.ftip_target = np.zeros(len(Z1GripperArmJointIndex) - 1, dtype=np.float16)
+        self.time_target = time.monotonic()
+
+        self.DELTA_GRIPPER_CMD = 5.0 / 20.0 / 25.6
+        self.arm_cmd = "schedule_waypoint"
+
+        self.ctrl_lock = threading.Lock()
+
+        self.lowstate_buffer = DataBuffer()
+        self.z1_arm_ik = Z1_Arm_IK(unit_test=self.unit_test, visualization=False)
+        self.logger =
RichLogger(self.log_level) + + self.is_connected = False + self.grasped = False + + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] + + def _start_daemon_thread(self, target_fn: Callable[[], None], name: str | None = None) -> threading.Thread: + thread = threading.Thread(target=target_fn, name=name) + thread.daemon = True + thread.start() + return thread + + def connect(self): + try: + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + "Z1_Arm is already connected. Do not run `robot.connect()` twice." + ) + # Initialize arms + self.z1 = unitree_z1.ArmInterface() + self.z1_model = self.z1._ctrlComp.armModel + self.z1.setFsmLowcmd() + self.z1.lowcmd.setControlGain(self.robot_kp, self.robot_kd) + self.z1.sendRecv() + + self.subscribe_thread = self._start_daemon_thread( + self._subscribe_motor_state, name="z1._subscribe_motor_state" + ) + while not self.lowstate_buffer.get_data(): + time.sleep(0.01) + self.logger.warning("[Z1_ArmController] Waiting Get Data...") + + self.publish_thread = self._start_daemon_thread(self._ctrl_motor_state, name="z1._ctrl_motor_state") + self.is_connected = True + + except Exception as e: + self.disconnect() + self.logger.error(f"❌ Error in Z1ArmController.connect: {e}") + + def _subscribe_motor_state(self): + try: + while True: + lowstate = Z1LowState() + for motor_id in range(Robot_Num_Motors.Z1_7_Num_Motors - 1): + lowstate.motor_state[motor_id].q = self.z1.lowstate.getQ()[motor_id] + lowstate.motor_state[motor_id].dq = self.z1.lowstate.getQd()[motor_id] + + gripper_q = self.z1.lowstate.getGripperQ() + lowstate.motor_state[Robot_Num_Motors.Z1_7_Num_Motors - 1].q = gripper_q + lowstate.motor_state[Robot_Num_Motors.Z1_7_Num_Motors - 1].dq = 0.0 + + self.lowstate_buffer.set_data(lowstate) + time.sleep(self.control_dt) + + except Exception as e: + self.disconnect() + self.logger.error(f"❌ Error in Z1ArmController._subscribe_motor_state: {e}") + + def _update_z1_arm( + self, + q: np.ndarray, + qd: np.ndarray | None = None, + qdd: np.ndarray | None = None, + ftip: np.ndarray | None = None, + tau: np.ndarray | None = None, + ): + """Update the state and command of a given robotic arm.""" + current_gripper_q = self.read_current_gripper_q() + self.z1.q = q[: len(Z1GripperArmJointIndex) - 1] + self.z1.qd = self.dq_target if qd is None else qd + qdd = self.ddq_target if qdd is None else qdd + ftip = self.ftip_target if ftip is None else ftip + self.z1.tau = self.z1_model.inverseDynamics(self.z1.q, self.z1.qd, qdd, ftip) if tau is None else tau + self.z1.setArmCmd(self.z1.q, self.z1.qd, self.z1.tau) + + gripper_q = q[len(Z1GripperArmJointIndex) - 1] + self.z1.gripperQ = np.clip( + gripper_q, + current_gripper_q - self.DELTA_GRIPPER_CMD * 3, + current_gripper_q + self.DELTA_GRIPPER_CMD * 3, + ) + # self.z1.gripperQ = np.clip(gripper_q, current_gripper_q - self.DELTA_GRIPPER_CMD, current_gripper_q + self.DELTA_GRIPPER_CMD) if self.grasped else np.clip(gripper_q, current_gripper_q - self.DELTA_GRIPPER_CMD*4, current_gripper_q + self.DELTA_GRIPPER_CMD*4) # np.clip(gripper_q, current_gripper_q - self.DELTA_GRIPPER_CMD*3, current_gripper_q + self.DELTA_GRIPPER_CMD*3) + self.z1.setGripperCmd(self.z1.gripperQ, self.z1.gripperQd, self.z1.gripperTau) + self.z1.sendRecv() + time.sleep(self.control_dt) + self.grasped = 
abs(self.read_current_gripper_q() - current_gripper_q) < self.DELTA_GRIPPER_CMD / 12.0 + + def _drive_to_waypoint(self, target_pose: np.ndarray, t_insert_time: float): + curr_time = time.monotonic() + self.control_dt + t_insert = curr_time + t_insert_time + self.pose_interp = self.pose_interp.drive_to_waypoint( + pose=target_pose, + time=t_insert, + curr_time=curr_time, + max_pos_speed=self.max_pos_speed, + ) + + while time.monotonic() < t_insert: + self._update_z1_arm(self.pose_interp(time.monotonic())) + + def _schedule_waypoint( + self, + arm_q_target: np.ndarray, + arm_time_target: float, + t_now: float, + start_time: float, + last_waypoint_time: float, + arm_tauff_target: np.ndarray | None = None, + ) -> float: + target_time = time.monotonic() - time.perf_counter() + arm_time_target + curr_time = t_now + self.control_dt + target_time = max(target_time, curr_time + self.control_dt) + + self.pose_interp = self.pose_interp.schedule_waypoint( + pose=arm_q_target, + time=target_time, + max_pos_speed=self.max_pos_speed, + curr_time=curr_time, + last_waypoint_time=last_waypoint_time, + ) + last_waypoint_time = target_time + self._update_z1_arm(q=self.pose_interp(t_now), tau=arm_tauff_target) + + def _ctrl_motor_state(self): + try: + self.pose_interp = JointTrajectoryInterpolator( + times=[time.monotonic()], + joint_positions=[self.read_current_arm_q()], + ) + self.go_start() + + arm_q_target = self.read_current_arm_q() + arm_tauff_target = self.tauff_target + arm_time_target = time.monotonic() + arm_cmd = "schedule_waypoint" + last_waypoint_time = time.monotonic() + + while True: + start_time = time.perf_counter() + t_now = time.monotonic() + with self.ctrl_lock: + arm_q_target = self.q_target + arm_tauff_target = self.tauff_target + arm_time_target = self.time_target + arm_cmd = self.arm_cmd + + if arm_cmd == "drive_to_waypoint": + self._drive_to_waypoint(target_pose=arm_q_target, t_insert_time=0.8) + + elif arm_cmd == "schedule_waypoint": + self._schedule_waypoint( + arm_q_target=arm_q_target, + arm_time_target=arm_time_target, + t_now=t_now, + start_time=start_time, + last_waypoint_time=last_waypoint_time, + arm_tauff_target=arm_tauff_target, + ) + + except Exception as e: + self.disconnect() + self.logger.error(f"❌ Error in Z1ArmController._ctrl_motor_state: {e}") + + def read_current_arm_q(self) -> np.ndarray: + """Return current state q of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in Z1GripperArmJointIndex]) + + def read_current_arm_q_without_gripper(self) -> np.ndarray: + """Return current state q of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in list(Z1GripperArmJointIndex)[:-1]]) + + def read_current_gripper_q(self) -> np.ndarray: + """Return current state q of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[list(Z1GripperArmJointIndex)[-1].value].q]) + + def read_current_arm_dq(self) -> np.ndarray: + """Return current state dq of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[id].dq for id in Z1GripperArmJointIndex]) + + def read_current_arm_dq_without_gripper(self) -> np.ndarray: + """Return current state dq of the left and right arm motors.""" + return np.array( + [self.lowstate_buffer.get_data().motor_state[id].dq for id in list(Z1GripperArmJointIndex)[:-1]] + ) + + def write_arm( + self, + q_target: list[float] | np.ndarray, + tauff_target: 
list[float] | np.ndarray = None, + time_target: float | None = None, + cmd_target: str | None = None, + ): + """Set control target values q & tau of the left and right arm motors.""" + with self.ctrl_lock: + self.q_target = q_target + self.tauff_target = tauff_target + self.time_target = time_target + self.arm_cmd = cmd_target + + def arm_ik(self, ee_target: list[float] | np.ndarray) -> tuple[np.ndarray, np.ndarray] | None: + return self.z1_arm_ik.solve_ik( + ee_target, self.read_current_arm_q_without_gripper(), self.read_current_arm_dq_without_gripper() + ) + + def arm_fk(self, q: np.ndarray | None = None) -> np.ndarray | None: + return self.z1_model.forwardKinematics( + q if q is not None else self.read_current_arm_q(), len(Z1GripperArmJointIndex) + ) + + def go_start(self): + self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=1.0) + self.logger.success("Go Start OK!\n") + + def go_home(self): + self.z1.loopOn() + self.z1.backToStart() + self.z1.loopOff() + time.sleep(0.5) + self.logger.success("Go Home OK!\n") + + def disconnect(self): + self.is_connected = False + self.go_home() + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm_ik.py b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm_ik.py new file mode 100644 index 0000000..888860b --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_arm_ik.py @@ -0,0 +1,253 @@ +import time + +import casadi +import meshcat.geometry as mg +import numpy as np +import pinocchio as pin +from pinocchio import casadi as cpin +from pinocchio.visualize import MeshcatVisualizer + +from unitree_deploy.utils.weighted_moving_filter import WeightedMovingFilter + + +class Z1_Arm_IK: + def __init__(self, unit_test=False, visualization=False): + np.set_printoptions(precision=5, suppress=True, linewidth=200) + + self.unit_test = unit_test + self.visualization = visualization + + self.robot = pin.RobotWrapper.BuildFromURDF( + "unitree_deploy/robot_devices/assets/z1/z1.urdf", "unitree_deploy/robot_devices/assets/z1/" + ) + self.mixed_jointsToLockIDs = ["base_static_joint"] + + self.reduced_robot = self.robot.buildReducedRobot( + list_of_joints_to_lock=self.mixed_jointsToLockIDs, + reference_configuration=np.array([0.0] * self.robot.model.nq), + ) + + self.reduced_robot.model.addFrame( + pin.Frame( + "ee", + self.reduced_robot.model.getJointId("joint6"), + pin.SE3(np.eye(3), np.array([0.15, 0, 0]).T), + pin.FrameType.OP_FRAME, + ) + ) + + # for i in range(self.reduced_robot.model.nframes): + # frame = self.reduced_robot.model.frames[i] + # frame_id = self.reduced_robot.model.getFrameId(frame.name) + # print(f"Frame ID: {frame_id}, Name: {frame.name}") + + # Creating Casadi models and data for symbolic computing + self.cmodel = cpin.Model(self.reduced_robot.model) + self.cdata = self.cmodel.createData() + + self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1) + self.cTf = casadi.SX.sym("tf", 4, 4) + + cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + + self.EE_ID = self.reduced_robot.model.getFrameId("link06") + self.translational_error = casadi.Function( + "translational_error", + [self.cq, self.cTf], + [ + casadi.vertcat( + self.cdata.oMf[self.EE_ID].translation - self.cTf[:3, 3], + ) + ], + ) + self.rotational_error = casadi.Function( + "rotational_error", + [self.cq, self.cTf], + [ + casadi.vertcat( + cpin.log3(self.cdata.oMf[self.EE_ID].rotation @ self.cTf[:3, :3].T), + ) + ], + ) + + self.opti = 
casadi.Opti() + self.var_q = self.opti.variable(self.reduced_robot.model.nq) + self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth + self.param_tf = self.opti.parameter(4, 4) + self.translational_cost = casadi.sumsqr(self.translational_error(self.var_q, self.param_tf)) + self.rotation_cost = casadi.sumsqr(self.rotational_error(self.var_q, self.param_tf)) + self.regularization_cost = casadi.sumsqr(self.var_q) + self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last) + + # Setting optimization constraints and goals + self.opti.subject_to( + self.opti.bounded( + self.reduced_robot.model.lowerPositionLimit, + self.var_q, + self.reduced_robot.model.upperPositionLimit, + ) + ) + self.opti.minimize( + 50 * self.translational_cost + + self.rotation_cost + + 0.02 * self.regularization_cost + + 0.1 * self.smooth_cost + ) + # self.opti.minimize(20 * self.cost + self.regularization_cost) + + opts = { + "ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6}, + "print_time": False, # print or not + "calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F + } + self.opti.solver("ipopt", opts) + + self.init_data = np.zeros(self.reduced_robot.model.nq) + self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 6) + + self.vis = None + + if self.visualization: + # Initialize the Meshcat visualizer for visualization + self.vis = MeshcatVisualizer( + self.reduced_robot.model, self.reduced_robot.collision_model, self.reduced_robot.visual_model + ) + self.vis.initViewer(open=True) + self.vis.loadViewerModel("pinocchio") + self.vis.displayFrames(True, frame_ids=[101, 102], axis_length=0.15, axis_width=5) + self.vis.display(pin.neutral(self.reduced_robot.model)) + + # Enable the display of end effector target frames with short axis lengths and greater width. 
+ frame_viz_names = ["ee_target"] + frame_axis_positions = ( + np.array([[0, 0, 0], [1, 0, 0], [0, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 1]]) + .astype(np.float32) + .T + ) + frame_axis_colors = ( + np.array([[1, 0, 0], [1, 0.6, 0], [0, 1, 0], [0.6, 1, 0], [0, 0, 1], [0, 0.6, 1]]) + .astype(np.float32) + .T + ) + axis_length = 0.1 + axis_width = 10 + for frame_viz_name in frame_viz_names: + self.vis.viewer[frame_viz_name].set_object( + mg.LineSegments( + mg.PointsGeometry( + position=axis_length * frame_axis_positions, + color=frame_axis_colors, + ), + mg.LineBasicMaterial( + linewidth=axis_width, + vertexColors=True, + ), + ) + ) + + def solve_ik(self, wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + if current_lr_arm_motor_q is not None: + self.init_data = current_lr_arm_motor_q + self.opti.set_initial(self.var_q, self.init_data) + + # left_wrist, right_wrist = self.scale_arms(left_wrist, right_wrist) + if self.visualization: + self.vis.viewer["ee_target"].set_transform(wrist) # for visualization + + self.opti.set_value(self.param_tf, wrist) + self.opti.set_value(self.var_q_last, self.init_data) # for smooth + + try: + self.opti.solve() + # sol = self.opti.solve_limited() + + sol_q = self.opti.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + v = current_lr_arm_motor_dq * 0.0 if current_lr_arm_motor_dq is not None else (sol_q - self.init_data) * 0.0 + self.init_data = sol_q + sol_tauff = pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + if self.visualization: + self.vis.display(sol_q) # for visualization + + return sol_q, sol_tauff + + except Exception as e: + print(f"ERROR in convergence, plotting debug info.{e}") + + sol_q = self.opti.debug.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + v = current_lr_arm_motor_dq * 0.0 if current_lr_arm_motor_dq is not None else (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + sol_tauff = pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + if self.visualization: + self.vis.display(sol_q) # for visualization + + # return sol_q, sol_tauff + return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + + +if __name__ == "__main__": + arm_ik = Z1_Arm_IK(unit_test=True, visualization=True) + + # initial positon + L_tf_target = pin.SE3( + pin.Quaternion(1, 0, 0, 0), + np.array([0.25, 0, 0.2]), + ) + + rotation_speed = 0.02 + noise_amplitude_translation = 0.002 + noise_amplitude_rotation = 0.1 + + user_input = input("Please enter the start signal (enter 's' to start the subsequent program):\n") + if user_input.lower() == "s": + step = 0 + while True: + # Apply rotation noise with bias towards y and z axes + rotation_noise_l = pin.Quaternion( + np.cos(np.random.normal(0, noise_amplitude_rotation) / 2), + 0, + np.random.normal(0, noise_amplitude_rotation / 2), + 0, + ).normalized() # y bias + + if step <= 120: + angle = rotation_speed * step + L_tf_target.rotation = ( + rotation_noise_l * pin.Quaternion(np.cos(angle / 2), 0, np.sin(angle / 2), 0) + ).toRotationMatrix() # y axis + L_tf_target.translation += np.array([0.001, 0.001, 0.001]) + np.random.normal( + 0, noise_amplitude_translation, 3 + ) + else: + angle = rotation_speed * (240 - step) + L_tf_target.rotation = ( + rotation_noise_l * pin.Quaternion(np.cos(angle / 2), 0, np.sin(angle / 2), 0) + 
).toRotationMatrix() # y axis + L_tf_target.translation -= np.array([0.001, 0.001, 0.001]) + np.random.normal( + 0, noise_amplitude_translation, 3 + ) + + sol_q, _ = arm_ik.solve_ik(L_tf_target.homogeneous) + step += 1 + if step > 240: + step = 0 + time.sleep(0.01) diff --git a/unitree_deploy/unitree_deploy/robot_devices/arm/z1_dual_arm.py b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_dual_arm.py new file mode 100644 index 0000000..f7afa78 --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/arm/z1_dual_arm.py @@ -0,0 +1,347 @@ +import os +import sys +import threading +import time +from typing import Callable + +import numpy as np + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) +import unitree_arm_interface as unitree_z1 # type: ignore + +from unitree_deploy.robot_devices.arm.arm_indexs import Z1_12_JointArmIndex +from unitree_deploy.robot_devices.arm.configs import Z1DualArmConfig +from unitree_deploy.robot_devices.arm.z1_arm_ik import Z1_Arm_IK +from unitree_deploy.robot_devices.robots_devices_utils import ( + DataBuffer, + MotorState, + Robot_Num_Motors, + RobotDeviceAlreadyConnectedError, +) +from unitree_deploy.utils.joint_trajcetory_inter import JointTrajectoryInterpolator +from unitree_deploy.utils.rich_logger import log_error, log_info, log_success, log_warning + + +class Z1_12_LowState: + def __init__(self): + self.motor_state = [MotorState() for _ in range(Robot_Num_Motors.Z1_12_Num_Motors)] + + +class Z1_12_ArmController: + def __init__(self, config: Z1DualArmConfig): + log_info("Initialize Z1_12_ArmController...") + + self.left_robot_ip = config.left_robot_ip + self.left_robot_port1 = config.left_robot_port1 + self.left_robot_port2 = config.left_robot_port2 + self.right_robot_ip = config.right_robot_ip + self.right_robot_port1 = config.right_robot_port1 + self.right_robot_port2 = config.right_robot_port2 + + self.robot_kp = config.robot_kp + self.robot_kd = config.robot_kd + + self.max_pos_speed = config.max_pos_speed + self.init_pose_left = np.array(config.init_pose_left) + self.init_pose_right = np.array(config.init_pose_right) + self.init_pose = np.concatenate((self.init_pose_left, self.init_pose_right), axis=0) + self.unit_test = config.unit_test + self.control_dt = config.control_dt + self.motors = config.motors + self.mock = config.mock + + self.q_target = np.concatenate((self.init_pose_left, self.init_pose_right)) + self.tauff_target = np.zeros(len(Z1_12_JointArmIndex), dtype=np.float64) + self.dq_target = np.zeros(len(Z1_12_JointArmIndex) // 2, dtype=np.float64) + self.ddq_target = np.zeros(len(Z1_12_JointArmIndex) // 2, dtype=np.float64) + self.ftip_target = np.zeros(len(Z1_12_JointArmIndex) // 2, dtype=np.float64) + self.time_target = time.monotonic() + self.arm_cmd = "schedule_waypoint" + + self.ctrl_lock = threading.Lock() + self.lowstate_buffer = DataBuffer() + self.stop_event = threading.Event() + + self.z1_left_arm_ik = Z1_Arm_IK(unit_test=self.unit_test, visualization=False) + self.z1_right_arm_ik = Z1_Arm_IK(unit_test=self.unit_test, visualization=False) + + self.arm_indices_len = len(Z1_12_JointArmIndex) // 2 + + self.is_connected = False + + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] + + def _start_daemon_thread(self, target_fn: Callable[[], None], name: str | None = 
None) -> threading.Thread: + thread = threading.Thread(target=target_fn, name=name) + thread.daemon = True + thread.start() + return thread + + def initialize_arm(self, ip: str, port1: int, port2: int, name: str): + """Initialize z1.""" + arm = unitree_z1.ArmInterface(ip, port1, port2) + arm_model = arm._ctrlComp.armModel + arm.setFsmLowcmd() + return arm, arm_model + + def set_control_gains(self, kp, kd): + """Initialize kp kd.""" + for arm in [self.z1_left, self.z1_right]: + arm.lowcmd.setControlGain(kp, kd) + arm.sendRecv() + + def connect(self): + try: + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + "Z1_Dual_Arm is already connected. Do not run `robot.connect()` twice." + ) + # Initialize arms + self.z1_left, self.z1_left_model = self.initialize_arm( + self.left_robot_ip, self.left_robot_port1, self.left_robot_port2, "left" + ) + self.z1_right, self.z1_right_model = self.initialize_arm( + self.right_robot_ip, self.right_robot_port1, self.right_robot_port2, "right" + ) + + # Set control gains + self.set_control_gains(self.robot_kp, self.robot_kd) + + # initialize subscribe thread + self.subscribe_thread = self._start_daemon_thread(self._subscribe_motor_state, name="z1._subscribe_motor_state") + + while not self.lowstate_buffer.get_data(): + time.sleep(0.01) + log_warning("[Z1_12_ArmController] Waiting Get Data...") + + self.publish_thread = self._start_daemon_thread(self._ctrl_motor_state, name="z1_dual._ctrl_motor_state") + self.is_connected = True + + except Exception as e: + self.disconnect() + log_error(f"❌ Error in Z1_12_ArmController.connect: {e}") + + def _subscribe_motor_state(self): + while True: + msg = { + "q": np.concatenate([self.z1_left.lowstate.getQ(), self.z1_right.lowstate.getQ()], axis=0), + "dq": np.concatenate([self.z1_left.lowstate.getQd(), self.z1_right.lowstate.getQd()], axis=0), + } + if msg is not None: + lowstate = Z1_12_LowState() + for id in range(Robot_Num_Motors.Z1_12_Num_Motors): + lowstate.motor_state[id].q = msg["q"][id] + lowstate.motor_state[id].dq = msg["dq"][id] + self.lowstate_buffer.set_data(lowstate) + time.sleep(self.control_dt) + + def _update_z1_arm( + self, + arm, + arm_model, + q: np.ndarray, + qd: np.ndarray | None = None, + qdd: np.ndarray | None = None, + ftip: np.ndarray | None = None, + tau: np.ndarray | None = None, + ): + """Update the state and command of a given robotic arm.""" + arm.q = q + arm.qd = self.dq_target if qd is None else qd + qdd = self.ddq_target if qdd is None else qdd + ftip = self.ftip_target if ftip is None else ftip + arm.tau = arm_model.inverseDynamics(arm.q, arm.qd, qdd, ftip) if tau is None else tau + arm.setArmCmd(arm.q, arm.qd, arm.tau) + arm.sendRecv() + + def _drive_to_waypoint(self, target_pose: np.ndarray, t_insert_time: float): + curr_time = time.monotonic() + self.control_dt + t_insert = curr_time + t_insert_time + self.pose_interp = self.pose_interp.drive_to_waypoint( + pose=target_pose, + time=t_insert, + curr_time=curr_time, + max_pos_speed=self.max_pos_speed, + ) + + while time.monotonic() < t_insert: + self._update_z1_arm( + self.z1_left, self.z1_left_model, self.pose_interp(time.monotonic())[: self.arm_indices_len] + ) + self._update_z1_arm( + self.z1_right, self.z1_right_model, self.pose_interp(time.monotonic())[self.arm_indices_len :] + ) + time.sleep(self.control_dt) + + def _schedule_waypoint( + self, + arm_q_target: np.ndarray, + arm_time_target: float, + t_now: float, + start_time: float, + last_waypoint_time: float, + arm_tauff_target: np.ndarray | None = None, + ) -> 
float: + target_time = time.monotonic() - time.perf_counter() + arm_time_target + curr_time = t_now + self.control_dt + target_time = max(target_time, curr_time + self.control_dt) + + self.pose_interp = self.pose_interp.schedule_waypoint( + pose=arm_q_target, + time=target_time, + max_pos_speed=self.max_pos_speed, + curr_time=curr_time, + last_waypoint_time=last_waypoint_time, + ) + last_waypoint_time = target_time + self._update_z1_arm( + arm=self.z1_left, + arm_model=self.z1_left_model, + q=self.pose_interp(t_now)[: self.arm_indices_len], + tau=arm_tauff_target[: self.arm_indices_len] if arm_tauff_target is not None else arm_tauff_target, + ) + self._update_z1_arm( + arm=self.z1_right, + arm_model=self.z1_right_model, + q=self.pose_interp(t_now)[self.arm_indices_len :], + tau=arm_tauff_target[self.arm_indices_len :] if arm_tauff_target is not None else arm_tauff_target, + ) + + time.sleep(max(0, self.control_dt - (time.perf_counter() - start_time))) + + def _ctrl_motor_state(self): + try: + self.pose_interp = JointTrajectoryInterpolator( + times=[time.monotonic()], + joint_positions=[self.read_current_arm_q()], + ) + + self.go_start() + + arm_q_target = self.read_current_arm_q() + arm_tauff_target = self.tauff_target + arm_time_target = time.monotonic() + arm_cmd = "schedule_waypoint" + + last_waypoint_time = time.monotonic() + + while True: + start_time = time.perf_counter() + t_now = time.monotonic() + + with self.ctrl_lock: + arm_q_target = self.q_target + arm_tauff_target = self.tauff_target + arm_time_target = self.time_target + arm_cmd = self.arm_cmd + + if arm_cmd == "drive_to_waypoint": + self._drive_to_waypoint(target_pose=arm_q_target, t_insert_time=0.8) + + elif arm_cmd == "schedule_waypoint": + self._schedule_waypoint( + arm_q_target=arm_q_target, + arm_time_target=arm_time_target, + t_now=t_now, + start_time=start_time, + last_waypoint_time=last_waypoint_time, + arm_tauff_target=arm_tauff_target, + ) + + except Exception as e: + self.disconnect() + log_error(f"❌ Error in Z1ArmController._ctrl_motor_state: {e}") + + def write_arm( + self, + q_target: list[float] | np.ndarray, + tauff_target: list[float] | np.ndarray = None, + time_target: float | None = None, + cmd_target: str | None = None, + ): + """Set control target values q & tau of the left and right arm motors.""" + with self.ctrl_lock: + self.q_target = q_target + self.tauff_target = tauff_target + self.time_target = time_target + self.arm_cmd = cmd_target + + def read_current_arm_q(self) -> np.ndarray | None: + """Return current state q of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in Z1_12_JointArmIndex]) + + def read_current_arm_dq(self) -> np.ndarray | None: + """Return current state dq of the left and right arm motors.""" + return np.array([self.lowstate_buffer.get_data().motor_state[id].dq for id in Z1_12_JointArmIndex]) + + def arm_ik(self, l_tf_target, r_tf_target) -> np.ndarray | None: + current_lr_arm_q = self.read_current_arm_q() + current_lr_arm_dq = self.read_current_arm_dq() + + left_sol_q, left_sol_tauff = self.z1_left_arm_ik.solve_ik( + l_tf_target, + current_lr_arm_q[: self.arm_indices_len], + current_lr_arm_dq[: self.arm_indices_len], + ) + right_sol_q, right_sol_tauff = self.z1_right_arm_ik.solve_ik( + r_tf_target, + current_lr_arm_q[self.arm_indices_len :], + current_lr_arm_dq[self.arm_indices_len :], + ) + + sol_q = np.concatenate([left_sol_q, right_sol_q], axis=0) + sol_tauff = np.concatenate([left_sol_tauff, right_sol_tauff], 
axis=0) + + return sol_q, sol_tauff + + def arm_fk(self, left_q: np.ndarray | None = None, right_q: np.ndarray | None = None) -> np.ndarray | None: + left = self.z1_left_arm_ik.solve_fk( + left_q if left_q is not None else self.read_current_arm_q()[: self.arm_indices_len] + ) + right = self.z1_right_arm_ik.solve_fk( + right_q if right_q is not None else self.read_current_arm_q()[self.arm_indices_len :] + ) + + return left, right + + def go_start(self): + self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=1.0) + log_success("Go Start OK!\n") + + def go_home(self): + if self.mock: + self.stop_event.set() + # self.subscribe_thread.join() + # self.publish_thread.join() + time.sleep(1) + else: + self.is_connected = False + + self.z1_right.loopOn() + self.z1_right.backToStart() + self.z1_right.loopOff() + + self.z1_left.loopOn() + self.z1_left.backToStart() + self.z1_left.loopOff() + + time.sleep(0.5) + log_success("Go Home OK!\n") + + def disconnect(self): + self.is_connected = False + self.go_home() + + def __del__(self): + if getattr(self, "is_connected", False): + self.disconnect() diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/g1/.gitignore b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/.gitignore new file mode 100644 index 0000000..0a293e4 --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/.gitignore @@ -0,0 +1,2 @@ +*.gv +*.pdf diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/g1/README.md b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/README.md new file mode 100644 index 0000000..f666cf9 --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/README.md @@ -0,0 +1,33 @@ +# Unitree G1 Description (URDF & MJCF) + +## Overview + +This package includes a universal humanoid robot description (URDF & MJCF) for the [Unitree G1](https://www.unitree.com/g1/), developed by [Unitree Robotics](https://www.unitree.com/). + +MJCF/URDF for the G1 robot: + +| MJCF/URDF file name | `mode_machine` | Hip roll reduction ratio | Update status | dof#leg | dof#waist | dof#arm | dof#hand | +| ----------------------------- | :------------: | :----------------------: | ------------- | :-----: | :-------: | :-----: | :------: | +| `g1_23dof` | 1 | 14.5 | Beta | 6*2 | 1 | 5*2 | 0 | +| `g1_29dof` | 2 | 14.5 | Beta | 6*2 | 3 | 7*2 | 0 | +| `g1_29dof_with_hand` | 2 | 14.5 | Beta | 6*2 | 3 | 7*2 | 7*2 | +| `g1_29dof_lock_waist` | 3 | 14.5 | Beta | 6*2 | 1 | 7*2 | 0 | +| `g1_23dof_rev_1_0` | 4 | 22.5 | Up-to-date | 6*2 | 1 | 5*2 | 0 | +| `g1_29dof_rev_1_0` | 5 | 22.5 | Up-to-date | 6*2 | 3 | 7*2 | 0 | +| `g1_29dof_with_hand_rev_1_0` | 5 | 22.5 | Up-to-date | 6*2 | 3 | 7*2 | 7*2 | +| `g1_29dof_lock_waist_rev_1_0` | 6 | 22.5 | Up-to-date | 6*2 | 1 | 7*2 | 0 | +| `g1_dual_arm` | 9 | null | Up-to-date | 0 | 0 | 7*2 | 0 | + +## Visulization with [MuJoCo](https://github.com/google-deepmind/mujoco) + +1. Open MuJoCo Viewer + + ```bash + pip install mujoco + python -m mujoco.viewer + ``` + +2. Drag and drop the MJCF/URDF model file (`g1_XXX.xml`/`g1_XXX.urdf`) to the MuJoCo Viewer. 
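The same model can also be loaded from Python instead of drag-and-drop (a small sketch; it assumes `mujoco` is installed and the XML file is in the working directory):

```python
# Open the MuJoCo viewer with a G1 model preloaded.
import mujoco
import mujoco.viewer

model = mujoco.MjModel.from_xml_path("g1_body29_hand14.xml")
mujoco.viewer.launch(model)
```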
+
+## Note for teleoperate
+g1_body29_hand14 is modified from [g1_29dof_with_hand_rev_1_0](https://github.com/unitreerobotics/unitree_ros/blob/master/robots/g1_description/g1_29dof_with_hand_rev_1_0.urdf)
diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29.xml b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29.xml
new file mode 100644
index 0000000..93214dd
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29.xml
@@ -0,0 +1,381 @@
[381 added lines of MJCF XML; the markup did not survive text extraction]
diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.urdf b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.urdf
new file mode 100644
index 0000000..156a5fd
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.urdf
@@ -0,0 +1,1476 @@
[1476 added lines of URDF XML; the markup did not survive text extraction; no newline at end of file]
diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.xml b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.xml
new file mode 100644
index 0000000..01fd5e2
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/assets/g1/g1_body29_hand14.xml
@@ -0,0 +1,408 @@
[408 added lines of MJCF XML; the markup did not survive text extraction; no newline at end of file]
diff --git a/unitree_deploy/unitree_deploy/robot_devices/assets/z1/z1.urdf b/unitree_deploy/unitree_deploy/robot_devices/assets/z1/z1.urdf
new file mode 100644
index 0000000..a9baebe
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/assets/z1/z1.urdf
@@ -0,0 +1,261 @@
[261 added lines of URDF XML; only the text of six ros_control transmission blocks survived extraction, each declaring transmission_interface/SimpleTransmission with hardware_interface/EffortJointInterface and a mechanical reduction of 1]
diff --git a/unitree_deploy/unitree_deploy/robot_devices/cameras/configs.py b/unitree_deploy/unitree_deploy/robot_devices/cameras/configs.py
new file mode 100644
index 0000000..97750ed
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/cameras/configs.py
@@ -0,0 +1,125 @@
+"""
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in PyTorch},
+    howpublished = {Available at:
https://github.com/huggingface/lerobot},
+    year = {2024},
+}
+"""
+
+import abc
+from dataclasses import dataclass
+
+import draccus
+
+
+@dataclass
+class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+
+@CameraConfig.register_subclass("opencv")
+@dataclass
+class OpenCVCameraConfig(CameraConfig):
+    """
+    Example of tested options for Intel Real Sense D405:
+
+    ```python
+    OpenCVCameraConfig(0, 30, 640, 480)
+    OpenCVCameraConfig(0, 60, 640, 480)
+    OpenCVCameraConfig(0, 90, 640, 480)
+    OpenCVCameraConfig(0, 30, 1280, 720)
+    ```
+    """
+
+    camera_index: int
+    fps: int | None = None
+    width: int | None = None
+    height: int | None = None
+    color_mode: str = "rgb"
+    channels: int | None = None
+    rotation: int | None = None
+    mock: bool = False
+
+    def __post_init__(self):
+        if self.color_mode not in ["rgb", "bgr"]:
+            raise ValueError(
+                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
+            )
+
+        self.channels = 3
+
+        if self.rotation not in [-90, None, 90, 180]:
+            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
+
+
+@CameraConfig.register_subclass("intelrealsense")
+@dataclass
+class IntelRealSenseCameraConfig(CameraConfig):
+    """
+    Example of tested options for Intel Real Sense D405:
+
+    ```python
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=640, height=480)
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=60, width=640, height=480)
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=640, height=480, use_depth=True)
+    IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=640, height=480, rotation=90)
+    ```
+    """
+
+    name: str | None = None
+    serial_number: int | None = None
+    fps: int | None = None
+    width: int | None = None
+    height: int | None = None
+    color_mode: str = "rgb"
+    channels: int | None = None
+    use_depth: bool = False
+    force_hardware_reset: bool = True
+    rotation: int | None = None
+    mock: bool = False
+
+    def __post_init__(self):
+        # Use truthiness rather than `is None` so that empty strings also count as unset.
+        if bool(self.name) and bool(self.serial_number):
+            raise ValueError(
+                f"Only one of `name` or `serial_number` may be set, but {self.name=} and {self.serial_number=} were provided."
+            )
+
+        if self.color_mode not in ["rgb", "bgr"]:
+            raise ValueError(
+                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
+            )
+
+        self.channels = 3
+
+        at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
+        at_least_one_is_none = self.fps is None or self.width is None or self.height is None
+        if at_least_one_is_not_none and at_least_one_is_none:
+            raise ValueError(
+                "For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
+                f"but {self.fps=}, {self.width=}, {self.height=} were provided."
+            )
+
+        if self.rotation not in [-90, None, 90, 180]:
+            raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
+
+
+@CameraConfig.register_subclass("imageclient")
+@dataclass
+class ImageClientCameraConfig(CameraConfig):
+    head_camera_type: str
+    head_camera_id_numbers: list[int]
+    head_camera_image_shape: list[int]
+
+    wrist_camera_type: str | None = None
+    wrist_camera_id_numbers: list[int] | None = None
+    wrist_camera_image_shape: list[int] | None = None
+
+    aspect_ratio_threshold: float = 2.0
+    fps: int = 30
+    mock: bool = False
diff --git a/unitree_deploy/unitree_deploy/robot_devices/cameras/imageclient.py b/unitree_deploy/unitree_deploy/robot_devices/cameras/imageclient.py
new file mode 100644
index 0000000..76f31a9
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/cameras/imageclient.py
@@ -0,0 +1,312 @@
+"""
+This file contains utilities for receiving camera frames from a remote image server. For more info look at the `ImageClient` docstring.
+"""
+
+import struct
+import threading
+import time
+from collections import deque
+from multiprocessing import shared_memory
+
+import cv2
+import numpy as np
+import zmq
+
+from unitree_deploy.robot_devices.cameras.configs import ImageClientCameraConfig
+from unitree_deploy.robot_devices.robots_devices_utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceNotConnectedError,
+)
+from unitree_deploy.utils.rich_logger import log_error, log_info, log_success, log_warning
+
+
+class ImageClient:
+    def __init__(
+        self,
+        tv_img_shape=None,
+        tv_img_shm_name=None,
+        wrist_img_shape=None,
+        wrist_img_shm_name=None,
+        image_show=False,
+        server_address="192.168.123.164",
+        port=5555,
+        unit_test=False,
+    ):
+        """
+        tv_img_shape: Expected head-camera resolution shape (H, W, C). It must match the output of the image server.
+        tv_img_shm_name: Name of the shared memory block used to hand images to other processes (e.g. the Vuer viewer).
+        wrist_img_shape: Expected wrist-camera resolution shape (H, W, C). It should keep the same layout as tv_img_shape.
+        wrist_img_shm_name: Name of the shared memory block used to hand wrist images to other processes.
+        image_show: Whether to display received images in real time.
+        server_address: IP address of the machine running the image server script.
+        port: Port number to connect to; it must match the image server's port.
+        unit_test: When enabled on both the server and the client, measures image-transfer latency,
+            network jitter, frame loss rate and other statistics.
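+
+        A minimal usage sketch (the shapes are illustrative assumptions and must match
+        what the image server actually publishes):
+
+        ```python
+        from multiprocessing import shared_memory
+
+        import numpy as np
+
+        tv_shape = (480, 1280, 3)  # e.g. two 640x480 head views side by side
+        shm = shared_memory.SharedMemory(create=True, size=int(np.prod(tv_shape)))
+        client = ImageClient(tv_img_shape=tv_shape, tv_img_shm_name=shm.name)
+        client.receive_process()  # blocking; ImageClientCamera below runs it in a daemon thread
+        ```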
+ """ + self.running = True + self._image_show = image_show + self._server_address = server_address + self._port = port + + self.tv_img_shape = tv_img_shape + self.wrist_img_shape = wrist_img_shape + + self.tv_enable_shm = False + if self.tv_img_shape is not None and tv_img_shm_name is not None: + self.tv_image_shm = shared_memory.SharedMemory(name=tv_img_shm_name) + self.tv_img_array = np.ndarray(tv_img_shape, dtype=np.uint8, buffer=self.tv_image_shm.buf) + self.tv_enable_shm = True + + self.wrist_enable_shm = False + if self.wrist_img_shape is not None and wrist_img_shm_name is not None: + self.wrist_image_shm = shared_memory.SharedMemory(name=wrist_img_shm_name) + self.wrist_img_array = np.ndarray(wrist_img_shape, dtype=np.uint8, buffer=self.wrist_image_shm.buf) + self.wrist_enable_shm = True + + # Performance evaluation parameters + self._enable_performance_eval = unit_test + if self._enable_performance_eval: + self._init_performance_metrics() + + def _init_performance_metrics(self): + self._frame_count = 0 # Total frames received + self._last_frame_id = -1 # Last received frame ID + + # Real-time FPS calculation using a time window + self._time_window = 1.0 # Time window size (in seconds) + self._frame_times = deque() # Timestamps of frames received within the time window + + # Data transmission quality metrics + self._latencies = deque() # Latencies of frames within the time window + self._lost_frames = 0 # Total lost frames + self._total_frames = 0 # Expected total frames based on frame IDs + + def _update_performance_metrics(self, timestamp, frame_id, receive_time): + # Update latency + latency = receive_time - timestamp + self._latencies.append(latency) + + # Remove latencies outside the time window + while self._latencies and self._frame_times and self._latencies[0] < receive_time - self._time_window: + self._latencies.popleft() + + # Update frame times + self._frame_times.append(receive_time) + # Remove timestamps outside the time window + while self._frame_times and self._frame_times[0] < receive_time - self._time_window: + self._frame_times.popleft() + + # Update frame counts for lost frame calculation + expected_frame_id = self._last_frame_id + 1 if self._last_frame_id != -1 else frame_id + if frame_id != expected_frame_id: + lost = frame_id - expected_frame_id + if lost < 0: + log_info(f"[Image Client] Received out-of-order frame ID: {frame_id}") + else: + self._lost_frames += lost + log_info( + f"[Image Client] Detected lost frames: {lost}, Expected frame ID: {expected_frame_id}, Received frame ID: {frame_id}" + ) + self._last_frame_id = frame_id + self._total_frames = frame_id + 1 + + self._frame_count += 1 + + def _print_performance_metrics(self, receive_time): + if self._frame_count % 30 == 0: + # Calculate real-time FPS + real_time_fps = len(self._frame_times) / self._time_window if self._time_window > 0 else 0 + + # Calculate latency metrics + if self._latencies: + avg_latency = sum(self._latencies) / len(self._latencies) + max_latency = max(self._latencies) + min_latency = min(self._latencies) + jitter = max_latency - min_latency + else: + avg_latency = max_latency = min_latency = jitter = 0 + + # Calculate lost frame rate + lost_frame_rate = (self._lost_frames / self._total_frames) * 100 if self._total_frames > 0 else 0 + + log_info( + f"[Image Client] Real-time FPS: {real_time_fps:.2f}, Avg Latency: {avg_latency * 1000:.2f} ms, Max Latency: {max_latency * 1000:.2f} ms, \ + Min Latency: {min_latency * 1000:.2f} ms, Jitter: {jitter * 1000:.2f} ms, Lost Frame Rate: 
{lost_frame_rate:.2f}%" + ) + + def _close(self): + self._socket.close() + self._context.term() + if self._image_show: + cv2.destroyAllWindows() + log_success("Image client has been closed.") + + def receive_process(self): + # Set up ZeroMQ context and socket + self._context = zmq.Context() + self._socket = self._context.socket(zmq.SUB) + self._socket.connect(f"tcp://{self._server_address}:{self._port}") + self._socket.setsockopt_string(zmq.SUBSCRIBE, "") + + log_warning("\nImage client has started, waiting to receive data...") + try: + while self.running: + # Receive message + message = self._socket.recv() + receive_time = time.time() + + if self._enable_performance_eval: + header_size = struct.calcsize("dI") + try: + # Attempt to extract header and image data + header = message[:header_size] + jpg_bytes = message[header_size:] + timestamp, frame_id = struct.unpack("dI", header) + except struct.error as e: + log_error(f"[Image Client] Error unpacking header: {e}, discarding message.") + continue + else: + # No header, entire message is image data + jpg_bytes = message + # Decode image + np_img = np.frombuffer(jpg_bytes, dtype=np.uint8) + current_image = cv2.imdecode(np_img, cv2.IMREAD_COLOR) + if current_image is None: + log_error("[Image Client] Failed to decode image.") + continue + + if self.tv_enable_shm: + np.copyto(self.tv_img_array, np.array(current_image[:, : self.tv_img_shape[1]])) + + if self.wrist_enable_shm: + np.copyto(self.wrist_img_array, np.array(current_image[:, -self.wrist_img_shape[1] :])) + + if self._image_show: + height, width = current_image.shape[:2] + resized_image = cv2.resize(current_image, (width // 2, height // 2)) + cv2.imshow("Image Client Stream", resized_image) + if cv2.waitKey(1) & 0xFF == ord("q"): + self.running = False + + if self._enable_performance_eval: + self._update_performance_metrics(timestamp, frame_id, receive_time) + self._print_performance_metrics(receive_time) + + except KeyboardInterrupt: + log_error("Image client interrupted by user.") + except Exception as e: + log_error(f"[Image Client] An error occurred while receiving data: {e}") + finally: + self._close() + + +class ImageClientCamera: + def __init__(self, config: ImageClientCameraConfig): + self.config = config + self.fps = config.fps + self.head_camera_type = config.head_camera_type + self.head_camera_image_shape = config.head_camera_image_shape + self.head_camera_id_numbers = config.head_camera_id_numbers + self.wrist_camera_type = config.wrist_camera_type + self.wrist_camera_image_shape = config.wrist_camera_image_shape + self.wrist_camera_id_numbers = config.wrist_camera_id_numbers + self.aspect_ratio_threshold = config.aspect_ratio_threshold + self.mock = config.mock + + self.is_binocular = ( + len(self.head_camera_id_numbers) > 1 + or self.head_camera_image_shape[1] / self.head_camera_image_shape[0] > self.aspect_ratio_threshold + ) # self.is_binocular + + self.has_wrist_camera = self.wrist_camera_type is not None # self.has_wrist_camera + + self.tv_img_shape = ( + (self.head_camera_image_shape[0], self.head_camera_image_shape[1] * 2, 3) + if self.is_binocular + and not (self.head_camera_image_shape[1] / self.head_camera_image_shape[0] > self.aspect_ratio_threshold) + else (self.head_camera_image_shape[0], self.head_camera_image_shape[1], 3) + ) + + self.tv_img_shm = shared_memory.SharedMemory(create=True, size=np.prod(self.tv_img_shape) * np.uint8().itemsize) + self.tv_img_array = np.ndarray(self.tv_img_shape, dtype=np.uint8, buffer=self.tv_img_shm.buf) + self.wrist_img_shape 
= None
+        self.wrist_img_shm = None
+
+        if self.has_wrist_camera:
+            self.wrist_img_shape = (self.wrist_camera_image_shape[0], self.wrist_camera_image_shape[1] * 2, 3)
+            self.wrist_img_shm = shared_memory.SharedMemory(
+                create=True, size=np.prod(self.wrist_img_shape) * np.uint8().itemsize
+            )
+            self.wrist_img_array = np.ndarray(self.wrist_img_shape, dtype=np.uint8, buffer=self.wrist_img_shm.buf)
+        self.img_shm_name = self.tv_img_shm.name
+        self.is_connected = False
+
+    def connect(self):
+        try:
+            if self.is_connected:
+                raise RobotDeviceAlreadyConnectedError("ImageClientCamera is already connected.")
+
+            self.img_client = ImageClient(
+                tv_img_shape=self.tv_img_shape,
+                tv_img_shm_name=self.tv_img_shm.name,
+                wrist_img_shape=self.wrist_img_shape,
+                wrist_img_shm_name=self.wrist_img_shm.name if self.wrist_img_shm else None,
+            )
+
+            image_receive_thread = threading.Thread(target=self.img_client.receive_process, daemon=True)
+            image_receive_thread.start()
+
+            self.is_connected = True
+
+        except Exception as e:
+            self.disconnect()
+            log_error(f"❌ Error in ImageClientCamera.connect: {e}")
+
+    def read(self) -> np.ndarray:
+        # Synchronous reads are not supported; the client continuously streams into shared memory.
+        raise NotImplementedError("ImageClientCamera only supports `async_read()`.")
+
+    def async_read(self):
+        try:
+            if not self.is_connected:
+                raise RobotDeviceNotConnectedError(
+                    "ImageClient is not connected. Try running `camera.connect()` first."
+                )
+            current_tv_image = self.tv_img_array.copy()
+            current_wrist_image = self.wrist_img_array.copy() if self.has_wrist_camera else None
+
+            colors = {}
+            if self.is_binocular:
+                colors["cam_left_high"] = current_tv_image[:, : self.tv_img_shape[1] // 2]
+                colors["cam_right_high"] = current_tv_image[:, self.tv_img_shape[1] // 2 :]
+                if self.has_wrist_camera:
+                    colors["cam_left_wrist"] = current_wrist_image[:, : self.wrist_img_shape[1] // 2]
+                    colors["cam_right_wrist"] = current_wrist_image[:, self.wrist_img_shape[1] // 2 :]
+            else:
+                colors["cam_high"] = current_tv_image
+                if self.has_wrist_camera:
+                    colors["cam_left_wrist"] = current_wrist_image[:, : self.wrist_img_shape[1] // 2]
+                    colors["cam_right_wrist"] = current_wrist_image[:, self.wrist_img_shape[1] // 2 :]
+
+            return colors
+
+        except Exception as e:
+            self.disconnect()
+            log_error(f"❌ Error in ImageClientCamera.async_read: {e}")
+
+    def disconnect(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                "ImageClientCamera is not connected. Try running `camera.connect()` first."
+            )
+
+        self.tv_img_shm.unlink()
+        self.tv_img_shm.close()
+        if self.has_wrist_camera:
+            self.wrist_img_shm.unlink()
+            self.wrist_img_shm.close()
+        self.is_connected = False
+
+    def __del__(self):
+        if getattr(self, "is_connected", False):
+            self.disconnect()
diff --git a/unitree_deploy/unitree_deploy/robot_devices/cameras/intelrealsense.py b/unitree_deploy/unitree_deploy/robot_devices/cameras/intelrealsense.py
new file mode 100644
index 0000000..b552902
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/cameras/intelrealsense.py
@@ -0,0 +1,504 @@
+"""
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in PyTorch},
+    howpublished = {Available at: https://github.com/huggingface/lerobot},
+    year = {2024},
+}
+This file contains utilities for recording frames from Intel RealSense cameras.
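+
+A minimal sketch for listing the connected devices with the same `pyrealsense2`
+calls that `find_cameras` below relies on:
+
+```python
+import pyrealsense2 as rs
+
+for device in rs.context().query_devices():
+    # Each device reports a name such as "Intel RealSense D405" and a serial number string.
+    print(device.get_info(rs.camera_info.name), device.get_info(rs.camera_info.serial_number))
+```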
+""" + +import argparse +import concurrent.futures +import logging +import math +import shutil +import threading +import time +import traceback +from collections import Counter +from pathlib import Path +from threading import Thread + +import cv2 +import numpy as np +import pyrealsense2 as rs +from PIL import Image + +from unitree_deploy.robot_devices.cameras.configs import IntelRealSenseCameraConfig +from unitree_deploy.robot_devices.robots_devices_utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, + busy_wait, + capture_timestamp_utc, +) + +SERIAL_NUMBER_INDEX = 1 + + +def find_cameras(raise_when_empty=True, mock=False) -> list[dict]: + """ + Find the names and the serial numbers of the Intel RealSense cameras + connected to the computer. + """ + cameras = [] + for device in rs.context().query_devices(): + serial_number = device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX)) + name = device.get_info(rs.camera_info.name) + cameras.append( + { + "serial_number": serial_number, + "name": name, + } + ) + + if raise_when_empty and len(cameras) == 0: + raise OSError( + "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware." + ) + + return cameras + + +def save_image(img_array, serial_number, frame_index, images_dir): + try: + img = Image.fromarray(img_array) + path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) + logging.info(f"Saved image: {path}") + except Exception as e: + logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") + + +def save_images_from_cameras( + images_dir: Path, + serial_numbers: list[int] | None = None, + fps=None, + width=None, + height=None, + record_time_s=2, + mock=False, +): + """ + Initializes all the cameras and saves images to the directory. Useful to visually identify the camera + associated to a given serial number. + """ + if serial_numbers is None or len(serial_numbers) == 0: + camera_infos = find_cameras(mock=mock) + serial_numbers = [cam["serial_number"] for cam in camera_infos] + + print("Connecting cameras") + cameras = [] + for cam_sn in serial_numbers: + print(f"{cam_sn=}") + config = IntelRealSenseCameraConfig( + serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock + ) + camera = IntelRealSenseCamera(config) + camera.connect() + print( + f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})" + ) + cameras.append(camera) + + images_dir = Path(images_dir) + if images_dir.exists(): + shutil.rmtree( + images_dir, + ) + images_dir.mkdir(parents=True, exist_ok=True) + + print(f"Saving images to {images_dir}") + frame_index = 0 + start_time = time.perf_counter() + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + while True: + now = time.perf_counter() + + for camera in cameras: + # If we use async_read when fps is None, the loop will go full speed, and we will end up + # saving the same images from the cameras multiple times until the RAM/disk is full. 
+ image = camera.read() if fps is None else camera.async_read() + if image is None: + print("No Frame") + + bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + executor.submit( + save_image, + bgr_converted_image, + camera.serial_number, + frame_index, + images_dir, + ) + + if fps is not None: + dt_s = time.perf_counter() - now + busy_wait(1 / fps - dt_s) + + if time.perf_counter() - start_time > record_time_s: + break + + print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + + frame_index += 1 + finally: + print(f"Images have been saved to {images_dir}") + for camera in cameras: + camera.disconnect() + + +class IntelRealSenseCamera: + """ + The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras: + - is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux, + - can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(), + - depth map can be returned. + + To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera: + ```bash + python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras + ``` + + When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode + of the given camera will be used. + + Example of instantiating with a serial number: + ```python + from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig + + config = IntelRealSenseCameraConfig(serial_number=128422271347) + camera = IntelRealSenseCamera(config) + camera.connect() + color_image = camera.read() + # when done using the camera, consider disconnecting + camera.disconnect() + ``` + + Example of instantiating with a name if it's unique: + ``` + config = IntelRealSenseCameraConfig(name="Intel RealSense D405") + ``` + + Example of changing default fps, width, height and color_mode: + ```python + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720) + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480) + config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr") + # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera + ``` + + Example of returning depth: + ```python + config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True) + camera = IntelRealSenseCamera(config) + camera.connect() + color_image, depth_map = camera.read() + ``` + """ + + def __init__( + self, + config: IntelRealSenseCameraConfig, + ): + self.config = config + if config.name is not None: + self.serial_number = self.find_serial_number_from_name(config.name) + else: + self.serial_number = config.serial_number + self.fps = config.fps + self.width = config.width + self.height = config.height + self.channels = config.channels + self.color_mode = config.color_mode + self.use_depth = config.use_depth + self.force_hardware_reset = config.force_hardware_reset + self.mock = config.mock + + self.camera = None + self.is_connected = False + self.thread = None + self.stop_event = None + self.color_image = None + self.depth_map = None + self.logs = {} + + if self.mock: + import tests.mock_cv2 as cv2 + else: + import cv2 + + 
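+        # Map the configured rotation (in degrees) to the corresponding OpenCV rotate flag;
+        # `None` leaves frames unrotated.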
self.rotation = None + if config.rotation == -90: + self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE + elif config.rotation == 90: + self.rotation = cv2.ROTATE_90_CLOCKWISE + elif config.rotation == 180: + self.rotation = cv2.ROTATE_180 + + def find_serial_number_from_name(self, name): + camera_infos = find_cameras() + camera_names = [cam["name"] for cam in camera_infos] + this_name_count = Counter(camera_names)[name] + if this_name_count > 1: + raise ValueError( + f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." + ) + + name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} + cam_sn = name_to_serial_dict[name] + + return cam_sn + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is already connected." + ) + + if self.mock: + import tests.mock_pyrealsense2 as rs + else: + import pyrealsense2 as rs + + config = rs.config() + config.enable_device(str(self.serial_number)) + + if self.fps and self.width and self.height: + config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps) + else: + config.enable_stream(rs.stream.color) + + if self.use_depth: + if self.fps and self.width and self.height: + config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps) + else: + config.enable_stream(rs.stream.depth) + + self.camera = rs.pipeline() + try: + profile = self.camera.start(config) + is_camera_open = True + except RuntimeError: + is_camera_open = False + traceback.print_exc() + + # If the camera doesn't work, display the camera indices corresponding to + # valid cameras. + if not is_camera_open: + # Verify that the provided `serial_number` is valid before printing the traceback + camera_infos = find_cameras() + serial_numbers = [cam["serial_number"] for cam in camera_infos] + if self.serial_number not in serial_numbers: + raise ValueError( + f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. " + "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`." + ) + + raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).") + + color_stream = profile.get_stream(rs.stream.color) + color_profile = color_stream.as_video_stream_profile() + actual_fps = color_profile.fps() + actual_width = color_profile.width() + actual_height = color_profile.height() + + # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + # Using `OSError` since it's a broad that encompasses issues related to device communication + raise OSError( + f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." + ) + if self.width is not None and self.width != actual_width: + raise OSError( + f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}." + ) + if self.height is not None and self.height != actual_height: + raise OSError( + f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}." 
+ ) + + self.fps = round(actual_fps) + self.width = round(actual_width) + self.height = round(actual_height) + + self.is_connected = True + + def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) + of type `np.uint8`, contrarily to the pytorch format which is float channel first. + + When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format + height x width (e.g. 480 x 640) of type np.uint16. + + Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. + If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." + ) + + if self.mock: + import tests.mock_cv2 as cv2 + else: + import cv2 + + start_time = time.perf_counter() + + frame = self.camera.wait_for_frames(timeout_ms=5000) + + color_frame = frame.get_color_frame() + + if not color_frame: + raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") + + color_image = np.asanyarray(color_frame.get_data()) + + requested_color_mode = self.color_mode if temporary_color is None else temporary_color + if requested_color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + ) + + # IntelRealSense uses RGB format as default (red, green, blue). + if requested_color_mode == "bgr": + color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR) + + h, w, _ = color_image.shape + if h != self.height or w != self.width: + raise OSError( + f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + ) + + if self.rotation is not None: + color_image = cv2.rotate(color_image, self.rotation) + + # log the number of seconds it took to read the image + self.logs["delta_timestamp_s"] = time.perf_counter() - start_time + + # log the utc time at which the image was received + self.logs["timestamp_utc"] = capture_timestamp_utc() + + if self.use_depth: + depth_frame = frame.get_depth_frame() + if not depth_frame: + raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") + + depth_map = np.asanyarray(depth_frame.get_data()) + + h, w = depth_map.shape + if h != self.height or w != self.width: + raise OSError( + f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." + ) + + if self.rotation is not None: + depth_map = cv2.rotate(depth_map, self.rotation) + + return color_image, depth_map + else: + return color_image + + def read_loop(self): + while not self.stop_event.is_set(): + if self.use_depth: + self.color_image, self.depth_map = self.read() + else: + self.color_image = self.read() + + def async_read(self): + """Access the latest color image""" + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." 
+            )
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.read_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        num_tries = 0
+        while self.color_image is None:
+            num_tries += 1
+            time.sleep(1 / self.fps)
+            if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
+                raise Exception(
+                    "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
+                )
+
+        if self.use_depth:
+            return self.color_image, self.depth_map
+        else:
+            return self.color_image
+
+    def disconnect(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
+            )
+
+        if self.thread is not None and self.thread.is_alive():
+            # wait for the thread to finish
+            self.stop_event.set()
+            self.thread.join()
+            self.thread = None
+            self.stop_event = None
+
+        self.camera.stop()
+        self.camera = None
+
+        self.is_connected = False
+
+    def __del__(self):
+        if getattr(self, "is_connected", False):
+            self.disconnect()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
+    )
+    parser.add_argument(
+        "--serial-numbers",
+        type=int,
+        nargs="*",
+        default=None,
+        help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available cameras.",
+    )
+    parser.add_argument(
+        "--fps",
+        type=int,
+        default=30,
+        help="Set the number of frames recorded per second for all cameras. If not provided, use the default fps of each camera.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=640,
+        help="Set the width for all cameras. If not provided, use the default width of each camera.",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=480,
+        help="Set the height for all cameras. If not provided, use the default height of each camera.",
+    )
+    parser.add_argument(
+        "--images-dir",
+        type=Path,
+        default="outputs/images_from_intelrealsense_cameras",
+        help="Set directory to save a few frames for each camera.",
+    )
+    parser.add_argument(
+        "--record-time-s",
+        type=float,
+        default=2.0,
+        help="Set the number of seconds used to record the frames. By default, 2 seconds.",
+    )
+    args = parser.parse_args()
+    save_images_from_cameras(**vars(args))
diff --git a/unitree_deploy/unitree_deploy/robot_devices/cameras/opencv.py b/unitree_deploy/unitree_deploy/robot_devices/cameras/opencv.py
new file mode 100644
index 0000000..837a789
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/cameras/opencv.py
@@ -0,0 +1,483 @@
+"""
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in PyTorch},
+    howpublished = {Available at: https://github.com/huggingface/lerobot},
+    year = {2024},
+}
+This file contains utilities for recording frames from cameras. For more info look at the `OpenCVCamera` docstring.
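+
+A minimal sketch for listing candidate devices with the `find_cameras` helper defined
+below (the exact ports/indices depend on your machine):
+
+```python
+from unitree_deploy.robot_devices.cameras.opencv import find_cameras
+
+print(find_cameras())  # e.g. [{"port": "/dev/video0", "index": 0}, ...] on Linux
+```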
+""" + +import argparse +import concurrent.futures +import math +import platform +import shutil +import threading +import time +from pathlib import Path +from threading import Thread + +import cv2 +import numpy as np +from PIL import Image + +from unitree_deploy.robot_devices.cameras.configs import OpenCVCameraConfig +from unitree_deploy.robot_devices.robots_devices_utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, + busy_wait, + capture_timestamp_utc, +) + +# The maximum opencv device index depends on your operating system. For instance, +# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case +# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23. +# When you change the USB port or reboot the computer, the operating system might +# treat the same cameras as new devices. Thus we select a higher bound to search indices. +MAX_OPENCV_INDEX = 60 + + +def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: + cameras = [] + if platform.system() == "Linux": + print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") + possible_ports = [str(port) for port in Path("/dev").glob("video*")] + ports = _find_cameras(possible_ports, mock=mock) + for port in ports: + cameras.append( + { + "port": port, + "index": int(port.removeprefix("/dev/video")), + } + ) + else: + print( + "Mac or Windows detected. Finding available camera indices through " + f"scanning all indices from 0 to {MAX_OPENCV_INDEX}" + ) + possible_indices = range(max_index_search_range) + indices = _find_cameras(possible_indices, mock=mock) + for index in indices: + cameras.append( + { + "port": None, + "index": index, + } + ) + + return cameras + + +def _find_cameras( + possible_camera_ids: list[int | str], raise_when_empty=False, mock=False +) -> list[int | str]: + camera_ids = [] + for camera_idx in possible_camera_ids: + camera = cv2.VideoCapture(camera_idx) + is_open = camera.isOpened() + camera.release() + + if is_open: + print(f"Camera found at index {camera_idx}") + camera_ids.append(camera_idx) + + if raise_when_empty and len(camera_ids) == 0: + raise OSError( + "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, " + "or your camera driver, or make sure your camera is compatible with opencv2." + ) + + return camera_ids + + +def is_valid_unix_path(path: str) -> bool: + """Note: if 'path' points to a symlink, this will return True only if the target exists""" + p = Path(path) + return p.is_absolute() and p.exists() + + +def get_camera_index_from_unix_port(port: Path) -> int: + return int(str(port.resolve()).removeprefix("/dev/video")) + + +def save_image(img_array, camera_index, frame_index, images_dir): + img = Image.fromarray(img_array) + path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png" + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path), quality=100) + + +def save_images_from_cameras( + images_dir: Path, + camera_ids: list | None = None, + fps=None, + width=None, + height=None, + record_time_s=2, + mock=False, +): + """ + Initializes all the cameras and saves images to the directory. Useful to visually identify the camera + associated to a given camera index. 
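+
+    A minimal call sketch (index 0 is an assumption; use whatever `find_cameras` reports):
+
+    ```python
+    save_images_from_cameras(Path("outputs/test"), camera_ids=[0], fps=30, width=640, height=480)
+    ```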
+    """
+    if camera_ids is None or len(camera_ids) == 0:
+        camera_infos = find_cameras(mock=mock)
+        camera_ids = [cam["index"] for cam in camera_infos]
+
+    print("Connecting cameras")
+    cameras = []
+    for cam_idx in camera_ids:
+        config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
+        camera = OpenCVCamera(config)
+        camera.connect()
+        print(
+            f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
+            f"height={camera.height}, color_mode={camera.color_mode})"
+        )
+        cameras.append(camera)
+
+    images_dir = Path(images_dir)
+    if images_dir.exists():
+        shutil.rmtree(
+            images_dir,
+        )
+    images_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving images to {images_dir}")
+    frame_index = 0
+    start_time = time.perf_counter()
+    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+        while True:
+            now = time.perf_counter()
+
+            for camera in cameras:
+                # If we use async_read when fps is None, the loop will go full speed, and we will end up
+                # saving the same images from the cameras multiple times until the RAM/disk is full.
+                image = camera.read() if fps is None else camera.async_read()
+
+                executor.submit(
+                    save_image,
+                    image,
+                    camera.camera_index,
+                    frame_index,
+                    images_dir,
+                )
+
+            if fps is not None:
+                dt_s = time.perf_counter() - now
+                busy_wait(1 / fps - dt_s)
+
+            print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
+
+            if time.perf_counter() - start_time > record_time_s:
+                break
+
+            frame_index += 1
+
+    print(f"Images have been saved to {images_dir}")
+
+
+class OpenCVCamera:
+    """
+    The OpenCVCamera class allows you to efficiently record images from cameras. It relies on opencv2 to communicate
+    with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
+
+    An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
+    like the webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
+    might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.
+
+    To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
+    ```bash
+    python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
+    ```
+
+    When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
+    of the given camera will be used.
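+    After `connect()`, the `fps`, `width` and `height` attributes are overwritten with the values
+    actually reported by the camera, so reading them back reveals the device defaults.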
+ + Example of usage: + ```python + from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig + + config = OpenCVCameraConfig(camera_index=0) + camera = OpenCVCamera(config) + camera.connect() + color_image = camera.read() + # when done using the camera, consider disconnecting + camera.disconnect() + ``` + + Example of changing default fps, width, height and color_mode: + ```python + config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720) + config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480) + config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr") + # Note: might error out open `camera.connect()` if these settings are not compatible with the camera + ``` + """ + + def __init__(self, config: OpenCVCameraConfig): + self.config = config + self.camera_index = config.camera_index + self.port = None + + # Linux uses ports for connecting to cameras + if platform.system() == "Linux": + if isinstance(self.camera_index, int): + self.port = Path(f"/dev/video{self.camera_index}") + elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): + self.port = Path(self.camera_index) + # Retrieve the camera index from a potentially symlinked path + self.camera_index = get_camera_index_from_unix_port(self.port) + else: + raise ValueError(f"Please check the provided camera_index: {self.camera_index}") + + self.fps = config.fps + self.width = config.width + self.height = config.height + self.channels = config.channels + self.color_mode = config.color_mode + self.mock = config.mock + + self.camera = None + self.is_connected = False + self.thread = None + self.stop_event = None + self.color_image = None + self.logs = {} + + if self.mock: + import tests.mock_cv2 as cv2 + else: + import cv2 + + self.rotation = None + if config.rotation == -90: + self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE + elif config.rotation == 90: + self.rotation = cv2.ROTATE_90_CLOCKWISE + elif config.rotation == 180: + self.rotation = cv2.ROTATE_180 + + def connect(self): + if self.is_connected: + raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") + + if self.mock: + import tests.mock_cv2 as cv2 + else: + import cv2 + + # Use 1 thread to avoid blocking the main thread. Especially useful during data collection + # when other threads are used to save the images. + cv2.setNumThreads(1) + + camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index + # First create a temporary camera trying to access `camera_index`, + # and verify it is a valid camera by calling `isOpened`. + tmp_camera = cv2.VideoCapture(camera_idx) + is_camera_open = tmp_camera.isOpened() + # Release camera to make it accessible for `find_camera_indices` + tmp_camera.release() + del tmp_camera + + # If the camera doesn't work, display the camera indices corresponding to + # valid cameras. + if not is_camera_open: + # Verify that the provided `camera_index` is valid before printing the traceback + cameras_info = find_cameras() + available_cam_ids = [cam["index"] for cam in cameras_info] + if self.camera_index not in available_cam_ids: + raise ValueError( + f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. " + "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`." 
+ ) + + raise OSError(f"Can't access OpenCVCamera({camera_idx}).") + + # Secondly, create the camera that will be used downstream. + # Note: For some unknown reason, calling `isOpened` blocks the camera which then + # needs to be re-created. + self.camera = cv2.VideoCapture(camera_idx) + + if self.fps is not None: + self.camera.set(cv2.CAP_PROP_FPS, self.fps) + if self.width is not None: + self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width) + if self.height is not None: + self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height) + + actual_fps = self.camera.get(cv2.CAP_PROP_FPS) + actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH) + actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) + + # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + # Using `OSError` since it's a broad that encompasses issues related to device communication + raise OSError( + f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." + ) + if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3): + raise OSError( + f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}." + ) + if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3): + raise OSError( + f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}." + ) + + self.fps = round(actual_fps) + self.width = round(actual_width) + self.height = round(actual_height) + + self.is_connected = True + + def read(self, temporary_color_mode: str | None = None) -> np.ndarray: + """Read a frame from the camera returned in the format (height, width, channels) + (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first. + + Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. + If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. + """ + if not self.is_connected: + raise RobotDeviceNotConnectedError( + f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." + ) + + start_time = time.perf_counter() + + ret, color_image = self.camera.read() + + if not ret: + raise OSError(f"Can't capture color image from camera {self.camera_index}.") + + requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode + + if requested_color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." + ) + + # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images. + # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks, + # so we convert the image color from BGR to RGB. + if requested_color_mode == "rgb": + if self.mock: + import tests.mock_cv2 as cv2 + else: + import cv2 + + color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB) + + h, w, _ = color_image.shape + if h != self.height or w != self.width: + raise OSError( + f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." 
+            )
+
+        if self.rotation is not None:
+            color_image = cv2.rotate(color_image, self.rotation)
+
+        # log the number of seconds it took to read the image
+        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
+
+        # log the utc time at which the image was received
+        self.logs["timestamp_utc"] = capture_timestamp_utc()
+
+        self.color_image = color_image
+
+        return color_image
+
+    def read_loop(self):
+        while not self.stop_event.is_set():
+            try:
+                self.color_image = self.read()
+            except Exception as e:
+                print(f"Error reading in thread: {e}")
+
+    def async_read(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
+            )
+
+        if self.thread is None:
+            self.stop_event = threading.Event()
+            self.thread = Thread(target=self.read_loop, args=())
+            self.thread.daemon = True
+            self.thread.start()
+
+        num_tries = 0
+        while True:
+            if self.color_image is not None:
+                return self.color_image
+
+            time.sleep(1 / self.fps)
+            num_tries += 1
+            if num_tries > self.fps * 2:
+                raise TimeoutError("Timed out waiting for async_read() to start.")
+
+    def disconnect(self):
+        if not self.is_connected:
+            raise RobotDeviceNotConnectedError(
+                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
+            )
+
+        if self.thread is not None:
+            self.stop_event.set()
+            self.thread.join()  # wait for the thread to finish
+            self.thread = None
+            self.stop_event = None
+
+        self.camera.release()
+        self.camera = None
+        self.is_connected = False
+
+    def __del__(self):
+        if getattr(self, "is_connected", False):
+            self.disconnect()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
+    )
+    parser.add_argument(
+        "--camera-ids",
+        type=int,
+        nargs="*",
+        default=None,
+        help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
+    )
+    parser.add_argument(
+        "--fps",
+        type=int,
+        default=30,
+        help="Set the number of frames recorded per second for all cameras. If not provided, use the default fps of each camera.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=640,
+        help="Set the width for all cameras. If not provided, use the default width of each camera.",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=480,
+        help="Set the height for all cameras. If not provided, use the default height of each camera.",
+    )
+    parser.add_argument(
+        "--images-dir",
+        type=Path,
+        default="outputs/images_from_opencv_cameras",
+        help="Set directory to save a few frames for each camera.",
+    )
+    parser.add_argument(
+        "--record-time-s",
+        type=float,
+        default=4.0,
+        help="Set the number of seconds used to record the frames. By default, 4 seconds.",
+    )
+    args = parser.parse_args()
+    save_images_from_cameras(**vars(args))
diff --git a/unitree_deploy/unitree_deploy/robot_devices/cameras/utils.py b/unitree_deploy/unitree_deploy/robot_devices/cameras/utils.py
new file mode 100644
index 0000000..4eead3a
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/cameras/utils.py
@@ -0,0 +1,74 @@
+"""
+@misc{cadene2024lerobot,
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in PyTorch},
+    howpublished = {Available at: https://github.com/huggingface/lerobot},
+    year = {2024},
+}
+"""
+
+from typing import Protocol
+
+import numpy as np
+
+from unitree_deploy.robot_devices.cameras.configs import (
+    CameraConfig,
+    ImageClientCameraConfig,
+    IntelRealSenseCameraConfig,
+    OpenCVCameraConfig,
+)
+
+
+# Defines a camera type
+class Camera(Protocol):
+    def connect(self): ...
+    def read(self, temporary_color: str | None = None) -> np.ndarray: ...
+    def async_read(self) -> np.ndarray: ...
+    def disconnect(self): ...
+
+
+def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
+    cameras = {}
+
+    for key, cfg in camera_configs.items():
+        if cfg.type == "opencv":
+            from unitree_deploy.robot_devices.cameras.opencv import OpenCVCamera
+
+            cameras[key] = OpenCVCamera(cfg)
+
+        elif cfg.type == "intelrealsense":
+            from unitree_deploy.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
+
+            cameras[key] = IntelRealSenseCamera(cfg)
+
+        elif cfg.type == "imageclient":
+            from unitree_deploy.robot_devices.cameras.imageclient import ImageClientCamera
+
+            cameras[key] = ImageClientCamera(cfg)
+        else:
+            raise ValueError(f"The camera type '{cfg.type}' is not valid.")
+
+    return cameras
+
+
+def make_camera(camera_type, **kwargs) -> Camera:
+    if camera_type == "opencv":
+        from unitree_deploy.robot_devices.cameras.opencv import OpenCVCamera
+
+        config = OpenCVCameraConfig(**kwargs)
+        return OpenCVCamera(config)
+
+    elif camera_type == "intelrealsense":
+        from unitree_deploy.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
+
+        config = IntelRealSenseCameraConfig(**kwargs)
+        return IntelRealSenseCamera(config)
+
+    elif camera_type == "imageclient":
+        from unitree_deploy.robot_devices.cameras.imageclient import ImageClientCamera
+
+        config = ImageClientCameraConfig(**kwargs)
+        return ImageClientCamera(config)
+
+    else:
+        raise ValueError(f"The camera type '{camera_type}' is not valid.")
diff --git a/unitree_deploy/unitree_deploy/robot_devices/endeffector/configs.py b/unitree_deploy/unitree_deploy/robot_devices/endeffector/configs.py
new file mode 100644
index 0000000..0b51f15
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/endeffector/configs.py
@@ -0,0 +1,29 @@
+import abc
+from dataclasses import dataclass
+
+import draccus
+import numpy as np
+
+
+@dataclass
+class EndEffectorConfig(draccus.ChoiceRegistry, abc.ABC):
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+
+@EndEffectorConfig.register_subclass("dex_1")
+@dataclass
+class Dex1_GripperConfig(EndEffectorConfig):
+    motors: dict[str, tuple[int, str]]
+    unit_test: bool = False
+    init_pose: list | None = None
+    control_dt: float = 1 / 200
+    mock: bool = False
+    max_pos_speed: float = 180 * (np.pi / 180) * 2
+    topic_gripper_command: str = "rt/unitree_actuator/cmd"
+    topic_gripper_state: str = "rt/unitree_actuator/state"
+
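+    # control_dt = 1/200 s gives a 200 Hz command loop; __post_init__ below rejects
+    # periods shorter than 2 ms (i.e. rates above 500 Hz).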
+ def __post_init__(self): + if self.control_dt < 0.002: + raise ValueError(f"`control_dt` must > 1/500 (got {self.control_dt})") diff --git a/unitree_deploy/unitree_deploy/robot_devices/endeffector/gripper.py b/unitree_deploy/unitree_deploy/robot_devices/endeffector/gripper.py new file mode 100644 index 0000000..c8d7a4b --- /dev/null +++ b/unitree_deploy/unitree_deploy/robot_devices/endeffector/gripper.py @@ -0,0 +1,227 @@ +# for gripper +import threading +import time + +import numpy as np +from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber # dds +from unitree_sdk2py.idl.default import unitree_go_msg_dds__MotorCmd_ +from unitree_sdk2py.idl.unitree_go.msg.dds_ import MotorCmds_, MotorStates_ # idl + +from unitree_deploy.robot_devices.arm.arm_indexs import Gripper_Sigle_JointIndex +from unitree_deploy.robot_devices.endeffector.configs import Dex1_GripperConfig +from unitree_deploy.robot_devices.robots_devices_utils import DataBuffer, MotorState, Robot_Num_Motors +from unitree_deploy.utils.joint_trajcetory_inter import JointTrajectoryInterpolator +from unitree_deploy.utils.rich_logger import log_error, log_info, log_warning + + +class Gripper_LowState: + def __init__(self): + self.motor_state = [MotorState() for _ in range(Robot_Num_Motors.Dex1_Gripper_Num_Motors)] + + +class Dex1_Gripper_Controller: + def __init__(self, config: Dex1_GripperConfig): + log_info("Initialize Dex1_Gripper_Controller...") + + self.init_pose = np.array(config.init_pose) + + self.motors = config.motors + self.mock = config.mock + self.control_dt = config.control_dt + self.unit_test = config.unit_test + self.max_pos_speed = config.max_pos_speed + + self.topic_gripper_command = config.topic_gripper_command + self.topic_gripper_state = config.topic_gripper_state + + self.q_target = np.zeros(1) + self.tauff_target = np.zeros(1) + self.time_target = time.monotonic() + self.gripper_cmd = "schedule_waypoint" + + self.lowstate_buffer = DataBuffer() + self.ctrl_lock = threading.Lock() + + self.MAX_DIST = 5.45 + self.MIN_DIST = 0.0 + self.DELTA_GRIPPER_CMD = 0.18 + + self.is_connected = False + + @property + def motor_names(self) -> list[str]: + return list(self.motors.keys()) + + @property + def motor_models(self) -> list[str]: + return [model for _, model in self.motors.values()] + + @property + def motor_indices(self) -> list[int]: + return [idx for idx, _ in self.motors.values()] + + def connect(self): + try: + if self.unit_test: + ChannelFactoryInitialize(0) + + dq = 0.0 + tau = 0.0 + kp = 10.0 + kd = 0.05 + + # initialize gripper cmd msg + self.gripper_msg = MotorCmds_() + self.gripper_msg.cmds = [unitree_go_msg_dds__MotorCmd_() for _ in range(len(Gripper_Sigle_JointIndex))] + for id in Gripper_Sigle_JointIndex: + self.gripper_msg.cmds[id].dq = dq + self.gripper_msg.cmds[id].tau_est = tau + self.gripper_msg.cmds[id].kp = kp + self.gripper_msg.cmds[id].kd = kd + + # initialize handcmd publisher and handstate subscriber + self.GripperCmb_publisher = ChannelPublisher(self.topic_gripper_command, MotorCmds_) + self.GripperCmb_publisher.Init() + + self.GripperState_subscriber = ChannelSubscriber(self.topic_gripper_state, MotorStates_) + self.GripperState_subscriber.Init() + + # initialize subscribe thread + self.subscribe_state_thread = threading.Thread(target=self._subscribe_gripper_motor_state) + self.subscribe_state_thread.daemon = True + self.subscribe_state_thread.start() + + while not self.lowstate_buffer.get_data(): + time.sleep(0.01) + 
log_warning("[Dex1_Gripper_Controller] Waiting to subscribe dds...") + + self.gripper_control_thread = threading.Thread(target=self._ctrl_gripper_motor) + self.gripper_control_thread.daemon = True + self.gripper_control_thread.start() + + self.is_connected = True + + except Exception as e: + self.disconnect() + log_error(f"❌ Error in Dex1_Gripper_Controller.connect: {e}") + + def _subscribe_gripper_motor_state(self): + try: + while True: + gripper_msg = self.GripperState_subscriber.Read() + if gripper_msg is not None: + lowstate = Gripper_LowState() + for idx, id in enumerate(Gripper_Sigle_JointIndex): + lowstate.motor_state[idx].q = gripper_msg.states[id].q + lowstate.motor_state[idx].dq = gripper_msg.states[id].dq + self.lowstate_buffer.set_data(lowstate) + time.sleep(0.002) + except Exception as e: + self.disconnect() + log_error(f"❌ Error in Dex1_Gripper_Controller._subscribe_gripper_motor_state: {e}") + + def _update_gripper(self, gripper_q_target: np.ndarray): + current_qs = np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in Gripper_Sigle_JointIndex]) + clamped_qs = np.clip(gripper_q_target, current_qs - self.DELTA_GRIPPER_CMD, current_qs + self.DELTA_GRIPPER_CMD) + """set current left, right gripper motor state target q""" + for idx, id in enumerate(Gripper_Sigle_JointIndex): + self.gripper_msg.cmds[id].q = np.array(clamped_qs)[idx] + self.GripperCmb_publisher.Write(self.gripper_msg) + + def _drive_to_waypoint(self, target_pose: np.ndarray, t_insert_time: float): + curr_time = time.monotonic() + self.control_dt + t_insert = curr_time + t_insert_time + self.pose_interp = self.pose_interp.drive_to_waypoint( + pose=target_pose, + time=t_insert, + curr_time=curr_time, + max_pos_speed=self.max_pos_speed, + ) + + while time.monotonic() < t_insert: + self._update_gripper(self.pose_interp(time.monotonic())) + time.sleep(self.control_dt) + + def _ctrl_gripper_motor(self): + try: + self.pose_interp = JointTrajectoryInterpolator( + times=[time.monotonic()], + joint_positions=[self.read_current_endeffector_q()], + ) + + gripper_q_target = self.read_current_endeffector_q() + gripper_tauff_target = self.tauff_target + gripper_time_target = time.monotonic() + gripper_cmd = "schedule_waypoint" + + last_waypoint_time = time.monotonic() + while True: + start_time = time.perf_counter() + t_now = time.monotonic() + with self.ctrl_lock: + gripper_q_target = self.q_target + gripper_tauff_target = self.tauff_target # noqa: F841 + gripper_time_target = self.time_target + gripper_cmd = self.gripper_cmd + + if gripper_cmd is None: + self._update_gripper(gripper_q_target) + # time.sleep(max(0, (self.control_dt - (time.perf_counter() - start_time)))) + elif gripper_cmd == "drive_to_waypoint": + self._drive_to_waypoint(target_pose=gripper_q_target, t_insert_time=0.8) + + elif gripper_cmd == "schedule_waypoint": + target_time = time.monotonic() - time.perf_counter() + gripper_time_target + curr_time = t_now + self.control_dt + target_time = max(target_time, curr_time + self.control_dt) + + self.pose_interp = self.pose_interp.schedule_waypoint( + pose=gripper_q_target, + time=target_time, + max_pos_speed=self.max_pos_speed, + curr_time=curr_time, + last_waypoint_time=last_waypoint_time, + ) + last_waypoint_time = target_time + + self._update_gripper(self.pose_interp(t_now)) + time.sleep(max(0, (self.control_dt - (time.perf_counter() - start_time)))) + + except Exception as e: + self.disconnect() + log_error(f"❌ Error in Dex1_Gripper_Controller._ctrl_gripper_motor: {e}") + + def 
read_current_endeffector_q(self) -> np.ndarray:
+        # Motor inversion left is 1 and right is 0 TODO(gh): Correct this
+        motor_states = np.array([self.lowstate_buffer.get_data().motor_state[id].q for id in Gripper_Sigle_JointIndex])
+        return np.array(motor_states)
+
+    def read_current_endeffector_dq(self) -> np.ndarray:
+        # Motor inversion left is 1 and right is 0 TODO(gh): Correct this
+        motor_states_dq = np.array(
+            [self.lowstate_buffer.get_data().motor_state[id].dq for id in Gripper_Sigle_JointIndex]
+        )
+        return np.array(motor_states_dq)
+
+    def write_endeffector(
+        self,
+        q_target: list[float] | np.ndarray,
+        tauff_target: list[float] | np.ndarray | None = None,
+        time_target: float | None = None,
+        cmd_target: str | None = None,
+    ):
+        with self.ctrl_lock:
+            self.q_target = q_target
+            self.tauff_target = tauff_target
+            self.time_target = time_target
+            self.gripper_cmd = cmd_target
+
+    def go_start(self):
+        self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=0.8)
+
+    def go_home(self):
+        self._drive_to_waypoint(target_pose=self.init_pose, t_insert_time=0.8)
+
+    def disconnect(self):
+        self.is_connected = False
+        # self.go_home()
diff --git a/unitree_deploy/unitree_deploy/robot_devices/endeffector/utils.py b/unitree_deploy/unitree_deploy/robot_devices/endeffector/utils.py
new file mode 100644
index 0000000..fa6f467
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/endeffector/utils.py
@@ -0,0 +1,50 @@
+from typing import Protocol
+
+from unitree_deploy.robot_devices.endeffector.configs import (
+    Dex1_GripperConfig,
+    EndEffectorConfig,
+)
+
+
+class EndEffector(Protocol):
+    def connect(self): ...
+    def disconnect(self): ...
+    def motor_names(self): ...
+
+    def read_current_endeffector_q(self): ...
+    def read_current_endeffector_dq(self): ...
+    def write_endeffector(self): ...
+
+    def retarget_to_endeffector(self): ...
+    def endeffector_ik(self): ...
+
+    def go_start(self): ...
+    def go_home(self): ...
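+
+# NOTE: `EndEffector` is a typing.Protocol, so conformance is structural -- any class
+# providing these methods can be returned by the factories below without inheriting
+# from a base class. A minimal sketch of a test double (illustrative only, not part
+# of this repo):
+#
+#     class MockEndEffector:
+#         def connect(self): ...
+#         def disconnect(self): ...
+#         def motor_names(self): return ["gripper"]
+#         def read_current_endeffector_q(self): return np.zeros(1)
+#         def read_current_endeffector_dq(self): return np.zeros(1)
+#         def write_endeffector(self, q_target, tauff_target=None, time_target=None, cmd_target=None): ...
+#         def retarget_to_endeffector(self): ...
+#         def endeffector_ik(self): ...
+#         def go_start(self): ...
+#         def go_home(self): ...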
+
+
+def make_endeffector_motors_buses_from_configs(
+    endeffector_configs: dict[str, EndEffectorConfig],
+) -> dict[str, EndEffector]:
+    endeffector_motors_buses = {}
+
+    for key, cfg in endeffector_configs.items():
+        if cfg.type == "dex_1":
+            from unitree_deploy.robot_devices.endeffector.gripper import Dex1_Gripper_Controller
+
+            endeffector_motors_buses[key] = Dex1_Gripper_Controller(cfg)
+
+        else:
+            raise ValueError(f"The end-effector type '{cfg.type}' is not valid.")
+
+    return endeffector_motors_buses
+
+
+def make_endeffector_motors_bus(endeffector_type: str, **kwargs) -> EndEffector:
+    if endeffector_type == "dex_1":
+        from unitree_deploy.robot_devices.endeffector.gripper import Dex1_Gripper_Controller
+
+        config = Dex1_GripperConfig(**kwargs)
+        return Dex1_Gripper_Controller(config)
+
+    else:
+        raise ValueError(f"The end-effector type '{endeffector_type}' is not valid.")
diff --git a/unitree_deploy/unitree_deploy/robot_devices/robots_devices_utils.py b/unitree_deploy/unitree_deploy/robot_devices/robots_devices_utils.py
new file mode 100644
index 0000000..fcac5e4
--- /dev/null
+++ b/unitree_deploy/unitree_deploy/robot_devices/robots_devices_utils.py
@@ -0,0 +1,87 @@
+import platform
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from enum import IntEnum
+from typing import Optional
+
+
+class Robot_Num_Motors(IntEnum):
+    Z1_6_Num_Motors = 6
+    Z1_7_Num_Motors = 7
+    Z1_12_Num_Motors = 12
+
+    Dex1_Gripper_Num_Motors = 2
+    G1_29_Num_Motors = 35
+
+
+@dataclass
+class MotorState:
+    q: Optional[float] = None
+    dq: Optional[float] = None
+    tau: Optional[float] = None
+
+
+class DataBuffer:
+    def __init__(self) -> None:
+        self.data = None
+        self.lock = threading.Lock()
+
+    def get_data(self):
+        with self.lock:
+            return self.data
+
+    def set_data(self, data) -> None:
+        with self.lock:
+            self.data = data
+
+
+class RobotDeviceNotConnectedError(Exception):
+    """Exception raised when the robot device is not connected."""
+
+    def __init__(
+        self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
+    ):
+        self.message = message
+        super().__init__(self.message)
+
+
+class RobotDeviceAlreadyConnectedError(Exception):
+    """Exception raised when the robot device is already connected."""
+
+    def __init__(
+        self,
+        message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
+    ):
+        self.message = message
+        super().__init__(self.message)
+
+
+def capture_timestamp_utc():
+    return datetime.now(timezone.utc)
+
+
+def busy_wait(seconds):
+    if platform.system() == "Darwin":
+        # On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
+        # but it consumes CPU cycles.
+ end_time = time.perf_counter() + seconds + while time.perf_counter() < end_time: + pass + else: + # On Linux time.sleep is accurate + if seconds > 0: + time.sleep(seconds) + + +def precise_wait(t_end: float, slack_time: float = 0.001, time_func=time.monotonic): + t_start = time_func() + t_wait = t_end - t_start + if t_wait > 0: + t_sleep = t_wait - slack_time + if t_sleep > 0: + time.sleep(t_sleep) + while time_func() < t_end: + pass + return diff --git a/unitree_deploy/unitree_deploy/utils/eval_utils.py b/unitree_deploy/unitree_deploy/utils/eval_utils.py new file mode 100644 index 0000000..97afea7 --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/eval_utils.py @@ -0,0 +1,281 @@ +import logging +import sys +import time +import traceback +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar + +import cv2 +import numpy as np +import pandas as pd +import pyarrow as pa +import requests +import torch +import torchvision +from datasets import load_from_disk +from datasets.features.features import register_feature +from safetensors.torch import load_file + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + + +class LongConnectionClient: + def __init__(self, base_url): + self.session = requests.Session() + self.base_url = base_url + + def send_post(self, endpoint, json_data): + """send POST request to endpoint""" + url = f"{self.base_url}{endpoint}" + response = None + while True: + try: + response = self.session.post(url, json=json_data) + if response.status_code == 200: + data = response.json() + if data["result"] == "ok": + response = data + break + else: + logging.info(data["desc"]) + + time.sleep(1) + except Exception as e: + logging.error(f"An error occurred: {e}") + logging.error(traceback.format_exc()) + + return response + + def close(self): + """ "close session""" + self.session.close() + + def predict_action(self, language_instruction, batch) -> torch.Tensor: + # collect data + data = { + "language_instruction": language_instruction, + "observation.state": torch.stack(list(batch["observation.state"])).tolist(), + "observation.images.top": torch.stack(list(batch["observation.images.top"])).tolist(), + "action": torch.stack(list(batch["action"])).tolist(), + } + + # send data + endpoint = "/predict_action" + response = self.send_post(endpoint, data) + # action = torch.tensor(response['action']).unsqueeze(0) + action = torch.tensor(response["action"]) + return action + + +class ACTTemporalEnsembler: + def __init__(self, temporal_ensemble_coeff: float, chunk_size: int, exe_steps: int) -> None: + """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + + The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. + They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the + coefficient works: + - Setting it to 0 uniformly weighs all actions. + - Setting it positive gives more weight to older actions. + - Setting it negative gives more weight to newer actions. + NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This + results in older actions being weighed more highly than newer actions (the experiments documented in + https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be + detrimental: doing so aggressively may diminish the benefits of action chunking). 
+ + Here we use an online method for computing the average rather than caching a history of actions in + order to compute the average offline. For a simple 1D sequence it looks something like: + + ``` + import torch + + seq = torch.linspace(8, 8.5, 100) + print(seq) + + m = 0.01 + exp_weights = torch.exp(-m * torch.arange(len(seq))) + print(exp_weights) + + # Calculate offline + avg = (exp_weights * seq).sum() / exp_weights.sum() + print("offline", avg) + + # Calculate online + for i, item in enumerate(seq): + if i == 0: + avg = item + continue + avg *= exp_weights[:i].sum() + avg += item * exp_weights[i] + avg /= exp_weights[:i+1].sum() + print("online", avg) + ``` + """ + self.chunk_size = chunk_size + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + self.exe_steps = exe_steps + self.reset() + + def reset(self): + """Resets the online computation variables.""" + self.ensembled_actions = None + # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. + self.ensembled_actions_count = None + + def update(self, actions): + """ + Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all + time steps, and pop/return the next batch of actions in the sequence. + """ + self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + if self.ensembled_actions is None: + # Initializes `self._ensembled_action` to the sequence of actions predicted during the first + # time step of the episode. + self.ensembled_actions = actions.clone() + # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor + # operations later. + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + ) + else: + # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the online update for those entries. + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += ( + actions[:, : -self.exe_steps] * self.ensemble_weights[self.ensembled_actions_count] + ) + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + # The last action, which has no prior online average, needs to get concatenated onto the end. + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -self.exe_steps :]], dim=1) + self.ensembled_actions_count = torch.cat( + # [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-self.exe_steps:])] + [ + self.ensembled_actions_count, + torch.ones((self.exe_steps, 1), dtype=torch.long, device=self.ensembled_actions_count.device), + ] + ) + # "Consume" the first action. + + actions, self.ensembled_actions, self.ensembled_actions_count = ( + self.ensembled_actions[:, : self.exe_steps], + self.ensembled_actions[:, self.exe_steps :], + self.ensembled_actions_count[self.exe_steps :], + ) + return actions + + +@dataclass +class VideoFrame: + """ + Provides a type for a dataset containing video frames. 
+ + Example: + + ```python + data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}] + features = {"image": VideoFrame()} + Dataset.from_dict(data_dict, features=Features(features)) + ``` + """ + + pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()}) + _type: str = field(default="VideoFrame", init=False, repr=False) + + def __call__(self): + return self.pa_type + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "'register_feature' is experimental and might be subject to breaking changes in the future.", + category=UserWarning, + ) + # to make VideoFrame available in HuggingFace `datasets` + register_feature(VideoFrame, "VideoFrame") + + +def get_image(cam_list, target_shape=None, save_image=False): + curr_images = [] + for cam in cam_list: + color, _ = cam.get_frame() + if save_image: + cv2.imwrite("/home/world-model-x/output.png", color) + color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) + if target_shape: + color = cv2.resize(color, target_shape) + curr_images.append(color) + curr_images = np.stack(curr_images, axis=0) + return curr_images + + +def load_action_from_dataset(dataset_dir, episode_id): + data = load_from_disk(dataset_dir + "/train") + episode_data = load_file(dataset_dir + "/meta_data/episode_data_index.safetensors") + start_id = episode_data["from"][episode_id] + end_id = episode_data["to"][episode_id] + actions = torch.FloatTensor(data["action"][start_id:end_id]) + return actions + + +def load_stats_from_prompt_dir(dataset_dir, prompt_dir, subdir=""): + dataset_dir += subdir + "/meta_data" + stats = load_file(dataset_dir + "/stats.safetensors") + return stats + + +def populate_queues(queues, batch): + for key in batch: + # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the + # queues have the keys they want). + if key not in queues: + continue + if len(queues[key]) != queues[key].maxlen: + # initialize by copying the first observation several times until the queue is full + while len(queues[key]) != queues[key].maxlen: + queues[key].append(batch[key]) + else: + # add latest observation to the queue + queues[key].append(batch[key]) + return queues + + +def action_safe_checking(action, action_max, action_min, threshold=0.01): + over_max = any(action - threshold > action_max.cpu().numpy()) + over_min = any(action + threshold < action_min.cpu().numpy()) + return not (over_max or over_min) + + +def get_init_pose(dataset_dir, start_id=0): + # load all par + dataset_dir_path = Path(dataset_dir) / "data" / "chunk-000" + parquet_files = list(dataset_dir_path.glob("*.parquet")) + parquet_files = sorted([str(f) for f in parquet_files]) + first_rows = [pd.read_parquet(f, engine="pyarrow").iloc[[0]] for f in parquet_files] + df = pd.concat(first_rows, ignore_index=True) + action_array = np.stack(df["action"].values) + init_pose = action_array[192:193, ...] 
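+    # NOTE: slicing 192:193 (rather than indexing 192) keeps a leading batch dimension,
+    # giving shape (1, action_dim); the row offset 192 itself looks dataset-specific.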
+ return init_pose + + +def save_image(obs, num_step=None, output_dir=None): + rgb_image = cv2.cvtColor(obs.observation["images"]["cam_left_high"], cv2.COLOR_BGR2RGB) + cv2.imwrite(f"{output_dir}/top_{num_step:06d}.png", rgb_image) + + +def log_to_tensorboard(writer, data, tag, fps=10): + if isinstance(data, torch.Tensor) and data.dim() == 5: + video = data + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video + ] # [3, n*h, 1*w] + grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + writer.add_video(tag, grid, fps=fps) diff --git a/unitree_deploy/unitree_deploy/utils/joint_trajcetory_inter.py b/unitree_deploy/unitree_deploy/utils/joint_trajcetory_inter.py new file mode 100644 index 0000000..fcd2d3f --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/joint_trajcetory_inter.py @@ -0,0 +1,196 @@ +"""The modification is derived from diffusion_policy/common/pose_trajectory_interpolator.py. Thank you for the outstanding contribution.""" + +import numbers +from typing import Union + +import numpy as np +import scipy.interpolate as si + + +def joint_pose_distance(start_joint_angles, end_joint_angles): + start_joint_angles = np.array(start_joint_angles) + end_joint_angles = np.array(end_joint_angles) + joint_angle_dist = np.linalg.norm(end_joint_angles - start_joint_angles) + + return joint_angle_dist + + +class JointTrajectoryInterpolator: + def __init__(self, times: np.ndarray, joint_positions: np.ndarray): + assert len(times) >= 1 + assert len(joint_positions) == len(times) + self.num_joints = len(joint_positions[0]) + if not isinstance(times, np.ndarray): + times = np.array(times) + if not isinstance(joint_positions, np.ndarray): + joint_positions = np.array(joint_positions) + if len(times) == 1: + self.single_step = True + self._times = times + self._joint_positions = joint_positions + else: + self.single_step = False + assert np.all(times[1:] >= times[:-1]) + self.interpolators = si.interp1d(times, joint_positions, axis=0, assume_sorted=True) + + @property + def times(self) -> np.ndarray: + if self.single_step: + return self._times + else: + return self.interpolators.x + + @property + def joint_positions(self) -> np.ndarray: + if self.single_step: + return self._joint_positions + else: + n = len(self.times) + joint_positions = np.zeros((n, self.num_joints)) + joint_positions = self.interpolators.y + return joint_positions + + def trim(self, start_t: float, end_t: float) -> "JointTrajectoryInterpolator": + assert start_t <= end_t + times = self.times + should_keep = (start_t < times) & (times < end_t) + keep_times = times[should_keep] + all_times = np.concatenate([[start_t], keep_times, [end_t]]) + all_times = np.unique(all_times) + all_joint_positions = self(all_times) + return JointTrajectoryInterpolator(times=all_times, joint_positions=all_joint_positions) + + def drive_to_waypoint( + self, + pose, + time, + curr_time, + max_pos_speed=np.inf, + ) -> "JointTrajectoryInterpolator": + assert max_pos_speed > 0 + time = max(time, curr_time) + + curr_pose = self(curr_time) + pos_dist = joint_pose_distance(curr_pose, pose) + pos_min_duration = pos_dist / max_pos_speed + duration = time - curr_time + duration = max(duration, pos_min_duration) + assert duration >= 0 + last_waypoint_time = curr_time + duration + + # insert new pose + trimmed_interp = self.trim(curr_time, curr_time) + 
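+        # NOTE: trimming to the zero-width window [curr_time, curr_time] collapses the
+        # interpolator to a single sample at the current pose, so the waypoint appended
+        # below defines the entire remaining trajectory.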
times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0) + poses = np.append(trimmed_interp.joint_positions, [pose], axis=0) + + # create new interpolator + final_interp = JointTrajectoryInterpolator(times, poses) + return final_interp + + def schedule_waypoint( + self, pose, time, max_pos_speed=np.inf, curr_time=None, last_waypoint_time=None + ) -> "JointTrajectoryInterpolator": + assert max_pos_speed > 0 + if last_waypoint_time is not None: + assert curr_time is not None + + # trim current interpolator to between curr_time and last_waypoint_time + start_time = self.times[0] + end_time = self.times[-1] + assert start_time <= end_time + + if curr_time is not None: + if time <= curr_time: + # if insert time is earlier than current time + # no effect should be done to the interpolator + return self + # now, curr_time < time + start_time = max(curr_time, start_time) + + if last_waypoint_time is not None: + # if last_waypoint_time is earlier than start_time + # use start_time + end_time = curr_time if time <= last_waypoint_time else max(last_waypoint_time, curr_time) + else: + end_time = curr_time + + end_time = min(end_time, time) + start_time = min(start_time, end_time) + + # end time should be the latest of all times except time after this we can assume order (proven by zhenjia, due to the 2 min operations) + # Constraints: + # start_time <= end_time <= time (proven by zhenjia) + # curr_time <= start_time (proven by zhenjia) + # curr_time <= time (proven by zhenjia) + + assert start_time <= end_time + assert end_time <= time + if last_waypoint_time is not None: + if time <= last_waypoint_time: + assert end_time == curr_time + else: + assert end_time == max(last_waypoint_time, curr_time) + + if curr_time is not None: + assert curr_time <= start_time + assert curr_time <= time + + trimmed_interp = self.trim(start_time, end_time) + + # determine speed + duration = time - end_time + end_pose = trimmed_interp(end_time) + pos_dist = joint_pose_distance(pose, end_pose) + + joint_min_duration = pos_dist / max_pos_speed + + duration = max(duration, joint_min_duration) + assert duration >= 0 + last_waypoint_time = end_time + duration + + # insert new pose + times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0) + poses = np.append(trimmed_interp.joint_positions, [pose], axis=0) + + # create new interpolator + final_interp = JointTrajectoryInterpolator(times, poses) + return final_interp + + def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray: + is_single = False + if isinstance(t, numbers.Number): + is_single = True + t = np.array([t]) + + joint_positions = np.zeros((len(t), self.num_joints)) + + if self.single_step: + joint_positions[:] = self._joint_positions[0] + else: + start_time = self.times[0] + end_time = self.times[-1] + t = np.clip(t, start_time, end_time) + joint_positions[:, :] = self.interpolators(t) + + if is_single: + joint_positions = joint_positions[0] + return joint_positions + + +def generate_joint_positions( + num_rows: int, num_cols: int, start: float = 0.0, step: float = 0.1, row_offset: float = 0.1 +) -> np.ndarray: + base_row = np.arange(start, start + step * num_cols, step) + array = np.vstack([base_row + i * row_offset for i in range(num_rows)]) + return array + + +if __name__ == "__main__": + # Example joint trajectory data (time in seconds, joint positions as an array of NUM_JOINTS joint angles) + times = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + joint_positions = generate_joint_positions(num_rows=5, num_cols=7, start=0.0, 
step=0.1, row_offset=0.1) + interpolator = JointTrajectoryInterpolator(times, joint_positions) + # Get joint positions at a specific time (e.g., t = 2.5 seconds) + t = 0.1 + joint_pos_at_t = interpolator(t) + print("Joint positions at time", t, ":", joint_pos_at_t) diff --git a/unitree_deploy/unitree_deploy/utils/rerun_visualizer.py b/unitree_deploy/unitree_deploy/utils/rerun_visualizer.py new file mode 100644 index 0000000..00fdf20 --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/rerun_visualizer.py @@ -0,0 +1,175 @@ +from datetime import datetime +from typing import Any, Dict, Optional, Tuple + +import rerun as rr +import rerun.blueprint as rrb +import torch + + +class RerunLogger: + """ + A fully automatic Rerun logger designed to parse and visualize step + dictionaries directly from a LeRobotDataset. + """ + + def __init__( + self, + prefix: str = "", + memory_limit: str = "200MB", + idxrangeboundary: Optional[int] = 300, + ): + """Initializes the Rerun logger.""" + # Use a descriptive name for the Rerun recording + rr.init(f"Dataset_Log_{datetime.now().strftime('%Y%m%d_%H%M%S')}") + rr.spawn(memory_limit=memory_limit) + + self.prefix = prefix + self.blueprint_sent = False + self.idxrangeboundary = idxrangeboundary + + # --- Internal cache for discovered keys --- + self._image_keys: Tuple[str, ...] = () + self._state_key: str = "" + self._action_key: str = "" + self._index_key: str = "index" + self._task_key: str = "task" + self._episode_index_key: str = "episode_index" + + self.current_episode = -1 + + def _initialize_from_data(self, step_data: Dict[str, Any]): + """Inspects the first data dictionary to discover components and set up the blueprint.""" + print("RerunLogger: First data packet received. Auto-configuring...") + + image_keys = [] + for key, value in step_data.items(): + if key.startswith("observation.images.") and isinstance(value, torch.Tensor) and value.ndim > 2: + image_keys.append(key) + elif key == "observation.state": + self._state_key = key + elif key == "action": + self._action_key = key + + self._image_keys = tuple(sorted(image_keys)) + + if "index" in step_data: + self._index_key = "index" + elif "frame_index" in step_data: + self._index_key = "frame_index" + + print(f" - Using '{self._index_key}' for time sequence.") + print(f" - Detected State Key: '{self._state_key}'") + print(f" - Detected Action Key: '{self._action_key}'") + print(f" - Detected Image Keys: {self._image_keys}") + if self.idxrangeboundary: + self.setup_blueprint() + + def setup_blueprint(self): + """Sets up and sends the Rerun blueprint based on detected components.""" + views = [] + + for key in self._image_keys: + clean_name = key.replace("observation.images.", "") + entity_path = f"{self.prefix}images/{clean_name}" + views.append(rrb.Spatial2DView(origin=entity_path, name=clean_name)) + + if self._state_key: + entity_path = f"{self.prefix}state" + views.append( + rrb.TimeSeriesView( + origin=entity_path, + name="Observation State", + time_ranges=[ + rrb.VisibleTimeRange( + "frame", + start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary), + end=rrb.TimeRangeBoundary.cursor_relative(), + ) + ], + plot_legend=rrb.PlotLegend(visible=True), + ) + ) + + if self._action_key: + entity_path = f"{self.prefix}action" + views.append( + rrb.TimeSeriesView( + origin=entity_path, + name="Action", + time_ranges=[ + rrb.VisibleTimeRange( + "frame", + start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary), + end=rrb.TimeRangeBoundary.cursor_relative(), + ) + ], + 
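+                    # NOTE: cursor_relative(seq=-idxrangeboundary) .. cursor_relative()
+                    # pins the plot to a sliding window of the most recent frames.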
plot_legend=rrb.PlotLegend(visible=True), + ) + ) + + if not views: + print("Warning: No visualizable components detected in the data.") + return + + grid = rrb.Grid(contents=views) + rr.send_blueprint(grid) + self.blueprint_sent = True + + def log_step(self, step_data: Dict[str, Any]): + """Logs a single step dictionary from your dataset.""" + if not self.blueprint_sent: + self._initialize_from_data(step_data) + + if self._index_key in step_data: + current_index = step_data[self._index_key].item() + rr.set_time_sequence("frame", current_index) + + episode_idx = step_data.get(self._episode_index_key, torch.tensor(-1)).item() + if episode_idx != self.current_episode: + self.current_episode = episode_idx + task_name = step_data.get(self._task_key, "Unknown Task") + log_text = f"Starting Episode {self.current_episode}: {task_name}" + rr.log(f"{self.prefix}info/task", rr.TextLog(log_text, level=rr.TextLogLevel.INFO)) + + for key in self._image_keys: + if key in step_data: + image_tensor = step_data[key] + if image_tensor.ndim > 2: + clean_name = key.replace("observation.images.", "") + entity_path = f"{self.prefix}images/{clean_name}" + if image_tensor.shape[0] in [1, 3, 4]: + image_tensor = image_tensor.permute(1, 2, 0) + rr.log(entity_path, rr.Image(image_tensor)) + + if self._state_key in step_data: + state_tensor = step_data[self._state_key] + entity_path = f"{self.prefix}state" + for i, val in enumerate(state_tensor): + rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item())) + + if self._action_key in step_data: + action_tensor = step_data[self._action_key] + entity_path = f"{self.prefix}action" + for i, val in enumerate(action_tensor): + rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item())) + + +def visualization_data(idx, observation, state, action, online_logger): + item_data: Dict[str, Any] = { + "index": torch.tensor(idx), + "observation.state": state, + "action": action, + } + for k, v in observation.items(): + if k not in ("index", "observation.state", "action"): + item_data[k] = v + # print(item_data) + online_logger.log_step(item_data) + + +def flatten_images(obs: dict) -> dict: + flat = {} + if "images" in obs: + for k, v in obs["images"].items(): + flat[f"observation.images.{k}"] = torch.from_numpy(v) + return flat diff --git a/unitree_deploy/unitree_deploy/utils/rich_logger.py b/unitree_deploy/unitree_deploy/utils/rich_logger.py new file mode 100644 index 0000000..10188a9 --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/rich_logger.py @@ -0,0 +1,180 @@ +import time + +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.text import Text + + +class RichLogger: + def __init__(self, level: str = "INFO"): + # Initialize the console for rich output + self.console = Console() + + # Define log levels with corresponding priority + self.levels = { + "DEBUG": 0, # Lowest level, all logs are displayed + "INFO": 1, # Standard level, displays Info and higher + "SUCCESS": 2, # Displays success and higher priority logs + "WARNING": 3, # Displays warnings and errors + "ERROR": 4, # Highest level, only errors are shown + } + + # Set default log level, use INFO if the level is invalid + self.level = self.levels.get(level.upper(), 1) + + def _log(self, level: str, message: str, style: str, emoji=None): + # Check if the current log level allows this message to be printed + if self.levels[level] < self.levels["INFO"]: + return + + # Format the timestamp + timestamp = 
time.strftime("%Y-%m-%d %H:%M:%S")
+
+        # If an emoji is provided, prepend it to the message
+        if emoji:
+            message = f"{emoji} {message}"
+
+        # Create a styled message
+        text = Text(f"[{timestamp}] [{level}] {message}", style=style)
+
+        # Print the message to the console
+        self.console.print(text)
+
+    # Basic log methods
+    def info(self, message: str, emoji: str | None = None):
+        # If the level is INFO or higher, print info log
+        if self.levels["INFO"] >= self.level:
+            self._log("INFO", message, "bold cyan", emoji)
+
+    def warning(self, message: str, emoji: str = "⚠️"):
+        # If the level is WARNING or higher, print warning log
+        if self.levels["WARNING"] >= self.level:
+            self._log("WARNING", message, "bold yellow", emoji)
+
+    def error(self, message: str, emoji: str = "❌"):
+        # If the level is ERROR or higher, print error log
+        if self.levels["ERROR"] >= self.level:
+            self._log("ERROR", message, "bold red", emoji)
+
+    def success(self, message: str, emoji: str = "🚀"):
+        # If the level is SUCCESS or higher, print success log
+        if self.levels["SUCCESS"] >= self.level:
+            self._log("SUCCESS", message, "bold green", emoji)
+
+    def debug(self, message: str, emoji: str = "🔍"):
+        # If the level is DEBUG or higher, print debug log
+        if self.levels["DEBUG"] >= self.level:
+            self._log("DEBUG", message, "dim", emoji)
+
+    # ========== Extended Features ==========
+    # Display a message with an emoji
+    def emoji(self, message: str, emoji: str = "🚀"):
+        self.console.print(f"{emoji} {message}", style="bold magenta")
+
+    # Show a loading animation for a certain period
+    def loading(self, message: str, seconds: float = 2.0):
+        # Display a loading message with a spinner animation
+        with self.console.status(f"[bold blue]{message}...", spinner="dots"):
+            time.sleep(seconds)
+
+    # Show a progress bar for small tasks
+    def progress(self, task_description: str, total: int = 100, speed: float = 0.02):
+        # Create and display a progress bar with time elapsed
+        with Progress(
+            SpinnerColumn(),
+            BarColumn(bar_width=None),
+            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+            TimeElapsedColumn(),
+            console=self.console,
+        ) as progress:
+            # Add a task to the progress bar
+            task = progress.add_task(f"[cyan]{task_description}", total=total)
+            while not progress.finished:
+                progress.update(task, advance=1)
+                time.sleep(speed)
+
+
+# ========== Singleton Logger Instance ==========
+_logger = RichLogger()
+
+
+# ========== Function-style API ==========
+def log_info(message: str, emoji: str | None = None):
+    _logger.info(message=message, emoji=emoji)
+
+
+def log_success(message: str, emoji: str = "🚀"):
+    _logger.success(message=message, emoji=emoji)
+
+
+def log_warning(message: str, emoji: str = "⚠️"):
+    _logger.warning(message=message, emoji=emoji)
+
+
+def log_error(message: str, emoji: str = "❌"):
+    _logger.error(message=message, emoji=emoji)
+
+
+def log_debug(message: str, emoji: str = "🔍"):
+    _logger.debug(message=message, emoji=emoji)
+
+
+def log_emoji(message: str, emoji: str = "🚀"):
+    _logger.emoji(message, emoji)
+
+
+def log_loading(message: str, seconds: float = 2.0):
_logger.loading(message, seconds) + + +def log_progress(task_description: str, total: int = 100, speed: float = 0.02): + _logger.progress(task_description, total, speed) + + +if __name__ == "__main__": + # Example usage: + # Initialize logger instance + logger = RichLogger(level="INFO") # Set initial log level to INFO + + # Log at different levels + logger.info("System initialization complete.") + logger.success("Robot started successfully!") + logger.warning("Warning: Joint temperature high!") + logger.error("Error: Failed to connect to robot") + logger.debug("Debug: Initializing motor controllers") + + # Display an emoji message + logger.emoji("This is a fun message with an emoji!", emoji="🔥") + + # Display loading animation for 3 seconds + logger.loading("Loading motor control data...", seconds=3) + + # Show progress bar for a task with 100 steps + logger.progress("Processing task", total=100, speed=0.05) + + # You can also use different log levels with a higher level than INFO, like ERROR: + logger = RichLogger(level="ERROR") + + # Only error and higher priority logs will be shown (INFO, SUCCESS, WARNING will be hidden) + logger.info("This won't be displayed because the level is set to ERROR") + logger.error("This error will be displayed!") diff --git a/unitree_deploy/unitree_deploy/utils/run_simulation.py b/unitree_deploy/unitree_deploy/utils/run_simulation.py new file mode 100644 index 0000000..6454882 --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/run_simulation.py @@ -0,0 +1,171 @@ +import time +from dataclasses import dataclass +from multiprocessing import Process, Queue +from queue import Empty + +import mujoco +import mujoco.viewer +import numpy as np + +from unitree_deploy.utils.rich_logger import log_info, log_success + + +@dataclass +class MujocoSimulationConfig: + xml_path: str + dof: int + robot_type: str + ctr_dof: int + stop_dof: int + + +def get_mujoco_sim_config(robot_type: str) -> MujocoSimulationConfig: + if robot_type == "g1": + return MujocoSimulationConfig( + xml_path="unitree_deploy/robot_devices/assets/g1/g1_body29.xml", + dof=30, + robot_type="g1", + ctr_dof=14, + stop_dof=35, + ) + elif robot_type == "z1": + return MujocoSimulationConfig( + xml_path="unitree_deploy/robot_devices/assets/z1/z1.xml", + dof=6, + robot_type="z1", + ctr_dof=6, + stop_dof=6, + ) + elif robot_type == "h1_2": + return MujocoSimulationConfig( + xml_path="unitree_deploy/robot_devices/assets/z1/z1.urdf", + dof=30, + robot_type="g1", + ctr_dof=14, + stop_dof=35, + ) + else: + raise ValueError(f"Unsupported robot_type: {robot_type}") + + +class MujicoSimulation: + def __init__(self, config: MujocoSimulationConfig): + self.xml_path = config.xml_path + + self.robot_type = config.robot_type + + self.dof = config.dof + self.ctr_dof = config.ctr_dof + self.stop_dof = config.stop_dof + + self.action_queue = Queue() + self.state_queue = Queue() + self.process = Process(target=self._run_simulation, args=(self.xml_path, self.action_queue, self.state_queue)) + self.process.daemon = True + self.process.start() + + def set_positions(self, joint_positions: np.ndarray): + if joint_positions.shape[0] != self.ctr_dof: + raise ValueError(f"joint_positions must contain {self.ctr_dof} values!") + + if self.robot_type == "g1": + joint_positions = np.concatenate([np.zeros(self.dof - self.ctr_dof, dtype=np.float32), joint_positions]) + elif self.robot_type == "z1": + pass + elif self.robot_type == "h1_2": + joint_positions[: self.dof - self.ctr_dof] = 0.0 + else: + raise ValueError(f"Unsupported 
robot_type: {self.robot_type}") + + self.action_queue.put(joint_positions.tolist()) + + def get_current_positions(self, timeout=0.01): + try: + return self.state_queue.get(timeout=timeout) + except Empty: + return [0.0] * self.stop_dof + + def stop(self): + if hasattr(self, "process") and self.process is not None and self.process.is_alive(): + try: + self.process.terminate() + self.process.join() + except Exception as e: + print(f"[WARN] Failed to stop process: {e}") + self.process = None + + for qname in ["action_queue", "state_queue"]: + queue = getattr(self, qname, None) + if queue is not None: + try: + if hasattr(queue, "close") and callable(queue.close): + queue.close() + if hasattr(queue, "join_thread") and callable(queue.join_thread): + queue.join_thread() + except Exception as e: + print(f"[WARN] Failed to cleanup {qname}: {e}") + setattr(self, qname, None) + + def __del__(self): + self.stop() + + @staticmethod + def _run_simulation(xml_path: str, action_queue: Queue, state_queue: Queue): + model = mujoco.MjModel.from_xml_path(xml_path) + data = mujoco.MjData(model) + + joint_names = [mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)] + joints_indices = [ + model.jnt_qposadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names + ] + log_info(f"len joints indices: {len(joints_indices)}") + + viewer = mujoco.viewer.launch_passive(model, data) + + current_positions = np.zeros(len(joints_indices), dtype=np.float32) + try: + while viewer.is_running(): + try: + new_pos = action_queue.get_nowait() + if len(new_pos) == len(joints_indices): + current_positions = new_pos + except Empty: + pass + + for idx, pos in zip(joints_indices, current_positions, strict=True): + data.qpos[idx] = pos + + data.qvel[:] = 0 + mujoco.mj_forward(model, data) + + state_queue.put(data.qpos.copy()) + + viewer.sync() + time.sleep(0.001) + + except KeyboardInterrupt: + log_success("The simulation process was interrupted.") + finally: + viewer.close() + + +def main(): + config = get_mujoco_sim_config(robot_type="g1") + sim = MujicoSimulation(config) + time.sleep(1) # Allow time for the simulation to start + try: + while True: + positions = np.random.uniform(-1.0, 1.0, sim.ctr_dof) + + sim.set_positions(positions) + + # print(sim.get_current_positions()) + + time.sleep(1 / 50) + except KeyboardInterrupt: + print("Simulation stopped.") + sim.stop() + + +if __name__ == "__main__": + main() diff --git a/unitree_deploy/unitree_deploy/utils/trajectory_generator.py b/unitree_deploy/unitree_deploy/utils/trajectory_generator.py new file mode 100644 index 0000000..ea3958d --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/trajectory_generator.py @@ -0,0 +1,36 @@ +import math + +import numpy as np +import pinocchio as pin + + +def generate_rotation(step: int, rotation_speed: float, max_step: int = 240): + """Generate rotation (quaternions) and translation deltas for left and right arm motions.""" + angle = rotation_speed * step if step <= max_step // 2 else rotation_speed * (max_step - step) + + # Create rotation quaternion for left arm (around Y-axis) + l_quat = pin.Quaternion(np.cos(angle / 2), 0, np.sin(angle / 2), 0) + + # Create rotation quaternion for right arm (around Z-axis) + r_quat = pin.Quaternion(np.cos(angle / 2), 0, 0, np.sin(angle / 2)) + + # Define translation increments for left and right arm + delta_l = np.array([0.001, 0.001, 0.001]) * 1.2 + delta_r = np.array([0.001, -0.001, 0.001]) * 1.2 + + # Reverse direction in second half of cycle + if 
step > max_step // 2: + delta_l *= -1 + delta_r *= -1 + + return l_quat, r_quat, delta_l, delta_r + + +def sinusoidal_single_gripper_motion(period: float, amplitude: float, current_time: float) -> np.ndarray: + value = amplitude * (math.sin(2 * math.pi * current_time / period) + 1) / 2 + return np.array([value*5]) + + +def sinusoidal_gripper_motion(period: float, amplitude: float, current_time: float) -> np.ndarray: + value = amplitude * (math.sin(2 * math.pi * current_time / period) + 1) / 2 + return np.array([value]*5) diff --git a/unitree_deploy/unitree_deploy/utils/weighted_moving_filter.py b/unitree_deploy/unitree_deploy/utils/weighted_moving_filter.py new file mode 100644 index 0000000..e87ac95 --- /dev/null +++ b/unitree_deploy/unitree_deploy/utils/weighted_moving_filter.py @@ -0,0 +1,40 @@ +import numpy as np + + +class WeightedMovingFilter: + def __init__(self, weights, data_size=14): + self._window_size = len(weights) + self._weights = np.array(weights) + assert np.isclose(np.sum(self._weights), 1.0), ( + "[WeightedMovingFilter] the sum of weights list must be 1.0!" + ) + self._data_size = data_size + self._filtered_data = np.zeros(self._data_size) + self._data_queue = [] + + def _apply_filter(self): + if len(self._data_queue) < self._window_size: + return self._data_queue[-1] + + data_array = np.array(self._data_queue) + temp_filtered_data = np.zeros(self._data_size) + for i in range(self._data_size): + temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1] + + return temp_filtered_data + + def add_data(self, new_data): + assert len(new_data) == self._data_size + + if len(self._data_queue) > 0 and np.array_equal(new_data, self._data_queue[-1]): + return # skip duplicate data + + if len(self._data_queue) >= self._window_size: + self._data_queue.pop(0) + + self._data_queue.append(new_data) + self._filtered_data = self._apply_filter() + + @property + def filtered_data(self): + return self._filtered_data diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768634421.node-0.53101.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768634421.node-0.53101.0 new file mode 100644 index 0000000..8b96762 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768634421.node-0.53101.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635752.node-0.54173.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635752.node-0.54173.0 new file mode 100644 index 0000000..7069bdf Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635752.node-0.54173.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635895.node-0.54367.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635895.node-0.54367.0 new file mode 100644 index 0000000..648d106 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635895.node-0.54367.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635978.node-0.54559.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635978.node-0.54559.0 new file mode 100644 index 0000000..a6d6800 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768635978.node-0.54559.0 differ diff --git 
a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768637429.node-0.56046.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768637429.node-0.56046.0 new file mode 100644 index 0000000..1bb06ec Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768637429.node-0.56046.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641146.node-0.58484.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641146.node-0.58484.0 new file mode 100644 index 0000000..442cd78 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641146.node-0.58484.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641226.node-0.58658.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641226.node-0.58658.0 new file mode 100644 index 0000000..b5fb445 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768641226.node-0.58658.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768649717.node-0.69434.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768649717.node-0.69434.0 new file mode 100644 index 0000000..3f9268a Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768649717.node-0.69434.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651877.node-0.87172.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651877.node-0.87172.0 new file mode 100644 index 0000000..7ab6920 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651877.node-0.87172.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651952.node-0.87853.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651952.node-0.87853.0 new file mode 100644 index 0000000..360db40 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768651952.node-0.87853.0 differ diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction_profile.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction_profile.sh new file mode 100644 index 0000000..d29da6b --- /dev/null +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction_profile.sh @@ -0,0 +1,34 @@ +res_dir="unitree_g1_pack_camera/case1" +dataset="unitree_g1_pack_camera" + +{ + time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \ + --seed 123 \ + --ckpt_path ckpts/unifolm_wma_dual.ckpt \ + --config configs/inference/world_model_interaction.yaml \ + --savedir "${res_dir}/output" \ + --bs 1 --height 320 --width 512 \ + --unconditional_guidance_scale 1.0 \ + --ddim_steps 50 \ + --ddim_eta 1.0 \ + --prompt_dir "unitree_g1_pack_camera/case1/world_model_interaction_prompts" \ + --dataset ${dataset} \ + --video_length 16 \ + --frame_stride 6 \ + --n_action_steps 16 \ + --exe_steps 16 \ + --n_iter 11 \ + --timestep_spacing 'uniform_trailing' \ + --guidance_rescale 0.7 \ + --perframe_ae \ + --profile \ + --profile_iterations 3 +} 2>&1 | tee "${res_dir}/output_profile.log" + +echo "" +echo "========================================" +echo "Profiling results saved to: ${res_dir}/output/profile_output/" +echo " - profiling_report.txt: Human-readable summary" +echo 
" - profiling_data.json: Detailed data for analysis" +echo " - trace_iter_*.json: Chrome trace files (open in chrome://tracing)" +echo "========================================" diff --git a/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768650624.node-0.77016.0 b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768650624.node-0.77016.0 new file mode 100644 index 0000000..40ee13a Binary files /dev/null and b/unitree_g1_pack_camera/case2/output/tensorboard/events.out.tfevents.1768650624.node-0.77016.0 differ diff --git a/unitree_g1_pack_camera/case3/output/tensorboard/events.out.tfevents.1768651549.node-0.84804.0 b/unitree_g1_pack_camera/case3/output/tensorboard/events.out.tfevents.1768651549.node-0.84804.0 new file mode 100644 index 0000000..95b357c Binary files /dev/null and b/unitree_g1_pack_camera/case3/output/tensorboard/events.out.tfevents.1768651549.node-0.84804.0 differ diff --git a/unitree_g1_pack_camera/case4/output/tensorboard/events.out.tfevents.1768652463.node-0.90349.0 b/unitree_g1_pack_camera/case4/output/tensorboard/events.out.tfevents.1768652463.node-0.90349.0 new file mode 100644 index 0000000..a141f30 Binary files /dev/null and b/unitree_g1_pack_camera/case4/output/tensorboard/events.out.tfevents.1768652463.node-0.90349.0 differ