第一次完整测例跑完
133
.gitignore
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.venv
|
||||
venv/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
*.pdf
|
||||
.pdf
|
||||
plot_test/
|
||||
plot/
|
||||
performance/
|
||||
localTest/
|
||||
fig/
|
||||
figure/
|
||||
*.mp4
|
||||
*.json
|
||||
Data/ControlVAE.yml
|
||||
Data/Misc
|
||||
Data/Pretrained
|
||||
Data/utils.py
|
||||
Experiment/checkpoint
|
||||
Experiment/log
|
||||
*.ckpt
|
||||
*.STL
|
||||
*.gif
|
||||
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "external/dlimp"]
|
||||
path = external/dlimp
|
||||
url = https://github.com/kvablack/dlimp
|
||||
439
LICENSE
Normal file
@@ -0,0 +1,439 @@
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
|
||||
Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
Public License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
||||
("Public License"). To the extent this Public License may be
|
||||
interpreted as a contract, You are granted the Licensed Rights in
|
||||
consideration of Your acceptance of these terms and conditions, and the
|
||||
Licensor grants You such rights in consideration of benefits the
|
||||
Licensor receives from making the Licensed Material available under
|
||||
these terms and conditions.
|
||||
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. BY-NC-SA Compatible License means a license listed at
|
||||
creativecommons.org/compatiblelicenses, approved by Creative
|
||||
Commons as essentially the equivalent of this Public License.
|
||||
|
||||
d. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
|
||||
e. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
g. License Elements means the license attributes listed in the name
|
||||
of a Creative Commons Public License. The License Elements of this
|
||||
Public License are Attribution, NonCommercial, and ShareAlike.
|
||||
|
||||
h. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
i. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
j. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
k. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
l. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
m. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
n. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. Additional offer from the Licensor -- Adapted Material.
|
||||
Every recipient of Adapted Material from You
|
||||
automatically receives an offer from the Licensor to
|
||||
exercise the Licensed Rights in the Adapted Material
|
||||
under the conditions of the Adapter's License You apply.
|
||||
|
||||
c. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
b. ShareAlike.
|
||||
|
||||
In addition to the conditions in Section 3(a), if You Share
|
||||
Adapted Material You produce, the following conditions also apply.
|
||||
|
||||
1. The Adapter's License You apply must be a Creative Commons
|
||||
license with the same License Elements, this version or
|
||||
later, or a BY-NC-SA Compatible License.
|
||||
|
||||
2. You must include the text of, or the URI or hyperlink to, the
|
||||
Adapter's License You apply. You may satisfy this condition
|
||||
in any reasonable manner based on the medium, means, and
|
||||
context in which You Share Adapted Material.
|
||||
|
||||
3. You may not offer or impose any additional or different terms
|
||||
or conditions on, or apply any Effective Technological
|
||||
Measures to, Adapted Material that restrict exercise of the
|
||||
rights granted under the Adapter's License You apply.
|
||||
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material,
|
||||
including for purposes of Section 3(b); and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
||||
228
README.md
Normal file
@@ -0,0 +1,228 @@
|
||||
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
|
||||
<p style="font-size: 1.2em;">
|
||||
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
|
||||
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
|
||||
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
|
||||
</p>
|
||||
<div align="center">
|
||||
<p align="right">
|
||||
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
|
||||
</p>
|
||||
</div>
|
||||
<div align="justify">
|
||||
<b>UnifoLM-WMA-0</b> is Unitree‘s open-source world-model–action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world-model capable of understanding the physical interactions between robots and the environments. This world-model provides two key functions: (a) <b>Simulation Engine</b> – operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b> – connects with an action head and, by predicting future interaction processes with the world-model, further optimizes decision-making performance.
|
||||
</div>
|
||||
|
||||
## 🦾 Real-Robot Demonstrations
|
||||
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|:---:|:---:|
|
||||
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|
||||
**Note: the top-right window shows the world model’s pretion of future action videos.**
|
||||
|
||||
## 🔥 News
|
||||
|
||||
* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
|
||||
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
|
||||
|
||||
## 📑 Opensource Plan
|
||||
- [x] Training
|
||||
- [x] Inference
|
||||
- [x] Checkpoints
|
||||
- [x] Deployment
|
||||
|
||||
## ⚙️ Installation
|
||||
```
|
||||
conda create -n unifolm-wma python==3.10.18
|
||||
conda activate unifolm-wma
|
||||
|
||||
conda install pinocchio=3.2.0 -c conda-forge -y
|
||||
conda install ffmpeg=7.1.1 -c conda-forge
|
||||
|
||||
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
|
||||
|
||||
# If you already downloaded the repo:
|
||||
cd unifolm-world-model-action
|
||||
git submodule update --init --recursive
|
||||
|
||||
pip install -e .
|
||||
|
||||
cd external/dlimp
|
||||
pip install -e .
|
||||
```
|
||||
## 🧰 Model Checkpoints
|
||||
| Model | Description | Link|
|
||||
|---------|-------|------|
|
||||
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|
||||
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree opensource dataset](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
|
||||
|
||||
## 🛢️ Dataset
|
||||
In our experiments, we consider the following three opensource dataset:
|
||||
| Dataset | Robot | Link |
|
||||
|---------|-------|------|
|
||||
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|
||||
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|
||||
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|
||||
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|
||||
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
|
||||
|
||||
To train on your own dataset, first to have the data following the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset’s source directory structure is as follows:
|
||||
```
|
||||
source_dir/
|
||||
├── dataset1_name
|
||||
├── dataset2_name
|
||||
├── dataset3_name
|
||||
└── ...
|
||||
```
|
||||
Then, convert a dataset to the required format using the command below:
|
||||
```python
|
||||
cd prepare_data
|
||||
python prepare_training_data.py \
|
||||
--source_dir /path/to/your/source_dir \
|
||||
--target_dir /path/to/save/the/converted/data \
|
||||
--dataset_name "dataset1_name" \
|
||||
--robot_name "a tag of the robot in the dataset" # e.g, Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
|
||||
```
|
||||
The resulting data structure (Note: model training only supports input from the main-view camera. If the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.
|
||||
```
|
||||
target_dir/
|
||||
├── videos
|
||||
│ ├──dataset1_name
|
||||
│ │ ├──camera_view_dir
|
||||
│ │ ├── 0.mp4
|
||||
│ │ ├── 1.mp4
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── transitions
|
||||
│ ├── dataset1_name
|
||||
│ ├── meta_data
|
||||
│ ├── 0.h5
|
||||
│ ├── 1.h5
|
||||
│ └── ...
|
||||
└── dataset1_name.csv
|
||||
```
|
||||
## 🚴♂️ Training
|
||||
A. Our training strategy is outlined as follows:
|
||||
- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
|
||||
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
|
||||
<div align="left">
|
||||
<img src="assets/pngs/dm_mode.png" width="600">
|
||||
</div>
|
||||
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
|
||||
<div align="left">
|
||||
<img src="assets/pngs/sim_mode.png" width="600">
|
||||
</div>
|
||||
**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
|
||||
|
||||
B. To conduct training on a single or multiple datasets, please follow the steps below:
|
||||
- **Step 1**: The maximum DoF is assumed to be 16, if you have more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml) ;
|
||||
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
|
||||
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
|
||||
```yaml
|
||||
model:
|
||||
pretrained_checkpoint: /path/to/pretrained/checkpoint;
|
||||
...
|
||||
decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
|
||||
...
|
||||
data:
|
||||
...
|
||||
train:
|
||||
...
|
||||
data_dir: /path/to/training/dataset/directory
|
||||
dataset_and_weights: # list the name of each dataset below and make sure the summation of weights is 1.0
|
||||
dataset1_name: 0.2
|
||||
dataset2_name: 0.2
|
||||
dataset3_name: 0.2
|
||||
dataset4_name: 0.2
|
||||
dataset5_name: 0.2
|
||||
```
|
||||
- **Step 4**: Setup ```experiment_name```, ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
|
||||
- **Step 5**: Launch the training with the command:
|
||||
```
|
||||
bash scripts/train.sh
|
||||
```
|
||||
## 🌏 Inference under Interactive Simulation Mode
|
||||
To run the world model in an interactive simulation mode, follow these steps:
|
||||
- **Step 1**: (Skip this step if you just would like to test using the examples we provided) Prepare your own prompt following the format used in the [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
|
||||
```
|
||||
world_model_interaction_prompts/
|
||||
├── images
|
||||
│ ├── dataset1_name
|
||||
│ │ ├── 0.png # Image prompt
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── transitions
|
||||
│ ├── dataset1_name
|
||||
│ │ ├── meta_data # Used for normalization
|
||||
│ │ ├── 0.h # Robot state and action data; in interaction mode,
|
||||
│ │ │ # only used to retrieve the robot state corresponding
|
||||
│ │ │ # to the image prompt
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── dataset1_name.csv # File for loading image prompts, text instruction and corresponding robot states
|
||||
└── ...
|
||||
```
|
||||
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint```(e.g, $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml)
|
||||
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all the dataset's name in ```datasets=(...)```. Then, launch the inference with the command:
|
||||
```
|
||||
bash scripts/run_world_model_interaction.sh
|
||||
```
|
||||
|
||||
## 🧠 Inference and Deployment under Decision-Making Mode
|
||||
|
||||
In this setup, inference is performed on a server, while a robot client gathers observations from the real-robot and sends them to the server to query actions. The process unfolds through the following steps:
|
||||
|
||||
### Server Setup:
|
||||
- **Step-1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
|
||||
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
|
||||
- **Step-3**: Launch the server:
|
||||
```
|
||||
conda activate unifolm-wma
|
||||
cd unifolm-world-model-action
|
||||
bash scripts/run_real_eval_server.sh
|
||||
```
|
||||
|
||||
### Client Setup
|
||||
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, launch the controllers or services on the real-robot.
|
||||
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
|
||||
```
|
||||
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
|
||||
```
|
||||
- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
|
||||
```
|
||||
cd unitree_deploy
|
||||
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
|
||||
```
|
||||
|
||||
## 📝 Codebase Architecture
|
||||
Here's a high-level overview of the project's code structure and core components:
|
||||
```
|
||||
unitree-world-model/
|
||||
├── assets # Media assets such as GIFs, images, and demo videos
|
||||
├── configs # Configuration files for training and inference
|
||||
│ ├── inference
|
||||
│ └── train
|
||||
├── examples # Example inputs and prompts for running inference
|
||||
├── external # External packages
|
||||
├── prepare_data # Scripts for dataset preprocessing and format conversion
|
||||
├── scripts # Main scripts for training, evaluation, and deployment
|
||||
├── src
|
||||
│ ├──unitree_worldmodel # Core Python package for the Unitree world model
|
||||
│ │ ├── data # Dataset loading, transformations, and dataloaders
|
||||
│ │ ├── models # Model architectures and backbone definitions
|
||||
│ │ ├── modules # Custom model modules and components
|
||||
│ │ └── utils # Utility functions and common helpers
|
||||
└── unitree_deploy # Deployment code
|
||||
```
|
||||
|
||||
## 🙏 Acknowledgement
|
||||
Lots of code are inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
|
||||
|
||||
## 📝 Citation
|
||||
```
|
||||
@misc{unifolm-wma-0,
|
||||
author = {Unitree},
|
||||
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
|
||||
year = {2025},
|
||||
}
|
||||
```
|
||||
216
README_cn.md
Normal file
@@ -0,0 +1,216 @@
|
||||
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
|
||||
<p style="font-size: 1.2em;">
|
||||
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>项目主页</strong></a> |
|
||||
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>开源模型</strong></a> |
|
||||
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>开源数据</strong></a>
|
||||
</p>
|
||||
<div align="center">
|
||||
<p align="right">
|
||||
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
**UnifoLM-WMA-0** 是宇树科技跨多类机器人本体的开源世界模型-动作架构,专为通用机器人学习而设计。其核心成分在于一个可以理解机器人与环境交互物理规律的世界模型。该世界模型具备两大核心功能:(1)**仿真引擎**,作为交互式仿真器运行,为机器人学习提供合成数据;(2)**策略增强**,可与一个动作头进行对接,通过预测未来与物理世界的交互过程,进一步优化决策性能。模型的真机部署效果如下所示,其中右上角小窗口是世界模型对于未来环境变化的预测,可辅助控制指令生成。
|
||||
|
||||
## 🦾 真机效果
|
||||
|
||||
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|:---:|:---:|
|
||||
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|
||||
**注:右上角小窗口显示世界模型对未来动作视频的预测。**
|
||||
|
||||
## 新闻
|
||||
* 2025年9月22日: 🚀 我们发布了应用宇树科技机器人进行真机实验的部署代码.
|
||||
* 2025年9月15日: 🚀 我们发布了 **UnifoLM-WMA-0** 的训练与推理代码,以及对应的模型权重.
|
||||
|
||||
|
||||
## 📑 开源计划
|
||||
- [x] 训练代码
|
||||
- [x] 推理代码
|
||||
- [x] 模型Checkpoints
|
||||
- [x] 真机部署代码
|
||||
|
||||
## ⚙️ 安装
|
||||
```
|
||||
conda create -n unifolm-wma python==3.10.18
|
||||
conda activate unifolm-wma
|
||||
|
||||
conda install pinocchio=3.2.0 -c conda-forge -y
|
||||
conda install ffmpeg=7.1.1 -c conda-forge
|
||||
|
||||
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
|
||||
|
||||
# If you already downloaded the repo:
|
||||
cd unifolm-world-model-action
|
||||
git submodule update --init --recursive
|
||||
|
||||
pip install -e .
|
||||
|
||||
cd external/dlimp
|
||||
pip install -e .
|
||||
```
|
||||
## 🧰 模型 Checkpoints
|
||||
| 模型 | 描述 | 链接 |
|
||||
|---------|-------|------|
|
||||
|$\text{UnifoLM-WMA-0}_{Base}$| 在 [Open-X](https://robotics-transformer-x.github.io/) 数据集微调后的模型 | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|
||||
|$\text{UnifoLM-WMA-0}_{Dual}$| 在五个[宇树科技开源数据集](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab)上,决策和仿真双模式,联合微调后的模型 | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
|
||||
|
||||
## 🛢️ 数据集
|
||||
实验中,我们训练测试了如下五个开源数据集:
|
||||
| 数据集 | 机器人 | 链接 |
|
||||
|---------|-------|------|
|
||||
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|
||||
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|
||||
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|
||||
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|
||||
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
|
||||
|
||||
要在自定义数据集上训练,请首先确保数据符合 [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) 数据集格式,假设下载后的数据目录结构如下:
|
||||
```
|
||||
source_dir/
|
||||
├── dataset1_name
|
||||
├── dataset2_name
|
||||
├── dataset3_name
|
||||
└── ...
|
||||
```
|
||||
随后执行以下命令进行格式转换:
|
||||
```python
|
||||
cd prepare_data
|
||||
python prepare_training_data.py \
|
||||
--source_dir /path/to/your/source_dir \
|
||||
--target_dir /path/to/save/the/converted/data/directory \
|
||||
--dataset_name "dataset1_name" \
|
||||
--robot_name "a tag of the robot in the dataset" # 例如: Unitree Z1 Robot Arm 或 Unitree G1 Robot with Gripper。
|
||||
```
|
||||
转换后的数据结构如下(注:模型训练只支持主视角相机输入, 如数据存在腕部视角,需删除CSV文件中```data_dir```列对应的视频路径):
|
||||
```
|
||||
target_dir/
|
||||
├── videos
|
||||
│ ├──dataset1_name
|
||||
│ │ ├──camera_view_dir
|
||||
│ │ ├── 0.mp4
|
||||
│ │ ├── 1.mp4
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── transitions
|
||||
│ ├── dataset1_name
|
||||
│ │ ├── meta_data
|
||||
│ │ ├── 0.h5
|
||||
│ │ ├── 1.h5
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
└── dataset1_name.csv
|
||||
```
|
||||
## 🚴 ♂️ 模型训练
|
||||
一. 我们的训练策略概括如下:
|
||||
- **步骤 1**:在 [Open-X](https://robotics-transformer-x.github.io/) 数据集上微调视频生成模型,使其作为世界模型(World Model);
|
||||
- **步骤 2**:在下游任务数据集上,对 $\text{UnifoLM-WMA}$ 进行决策模式(decision-making mode)后训练;
|
||||
<div align="left">
|
||||
<img src="assets/pngs/dm_mode.png" width="600">
|
||||
</div>
|
||||
- **步骤 3**:在下游任务数据集上,对 $\text{UnifoLM-WMA}$ 进行仿真模式(simulation mode)后训练。
|
||||
<div align="left">
|
||||
<img src="assets/pngs/sim_mode.png" width="600">
|
||||
</div>
|
||||
**注意**:如果只需要 $\text{UnifoLM-WMA}$ 在单一模式下运行,可以跳过相应的步骤。
|
||||
|
||||
二. 在单个或多个数据集上进行训练,请按照以下步骤操作:
|
||||
- **步骤1**:默认的最高自由度为16DOF,若需更多自由度,请修改[configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml) 中 ```agent_state_dim``` 及 ```agent_action_dim``` 的数值;
|
||||
- **步骤2**:在 [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json) 中为每种模态设置输入维度;
|
||||
- **步骤3**: 在 [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml) 中配置训练参数及路径。关于预训练的模型,推荐使用 $\text{UnifoLM-WMA-0}_{Base}$ ,其在[Open-X](https://robotics-transformer-x.github.io/) 数据集上微调过;
|
||||
```yaml
|
||||
model:
|
||||
pretrained_checkpoint: /path/to/pretrained/checkpoint
|
||||
...
|
||||
dicision_making_only: True # 是否只训练世界模型决策模式?如果否,则决策模式与仿真模式联合训练。
|
||||
...
|
||||
data:
|
||||
...
|
||||
train:
|
||||
...
|
||||
data_dir: /path/to/training/dataset/directory
|
||||
dataset_and_weights: # 列出所有数据集的名称及权重,确保权重和为1.0
|
||||
dataset1_name: 0.2
|
||||
dataset2_name: 0.2
|
||||
dataset3_name: 0.2
|
||||
dataset4_name: 0.2
|
||||
dataset5_name: 0.2
|
||||
```
|
||||
- **步骤4**: 在 [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh) 中配置```experiment_name```, ```save_root``` 变量;
|
||||
- **步骤5**: 运行如下指令开启训练:
|
||||
```
|
||||
bash scripts/train.sh
|
||||
```
|
||||
## 🌏 世界模型交互推理
|
||||
要启用世界模型的交互模式,请按以下步骤操作:
|
||||
- **步骤1**:(若仅用提供的实例进行测试,可跳过此步) 请按照 [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts) 目录中的格式,自定义提示词目录:
|
||||
```
|
||||
world_model_interaction_prompts/
|
||||
├── images
|
||||
│ ├── dataset1_name
|
||||
│ │ ├── 0.png # 图像提示词
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── transitions
|
||||
│ ├── dataset1_name
|
||||
│ │ ├── meta_data # 用于归一化
|
||||
│ │ ├── 0.h # 机器人状态、动作相关数据,在交互模式下仅用于获取与图像提示词对应的机器人状态
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── dataset1_name.csv # 该文件用于加载对应的:图像提示词、文本指令及机器人状态
|
||||
└── ...
|
||||
```
|
||||
- **步骤2**: 在 [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml) 中指定 ```pretrained_checkpoint``` (例如:$\text{UnifoLM-WMA-0}_{Dual}$) 和 ```data_dir``` 的正确路径;
|
||||
- **步骤3**: 在 [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh) 中指定```checkpoint```、```res_dir``` 和 ```prompt_dir```的正确路径,并在```datasets=(...)```中列出测试的数据集名称,然后用下述指令启动推理:
|
||||
```
|
||||
bash scripts/run_world_model_interaction.sh
|
||||
```
|
||||
|
||||
## 🧠 世界模型决策推理及部署
|
||||
在我们的系统中,推理在服务器端执行;机器人客户端从真实机器人收集观测信息并发送至服务器, 进行视频及动作推理。可通过如下步骤实现整个过程:
|
||||
|
||||
### 服务器端设置
|
||||
- **步骤1**: 在 [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh) 中指定 ```ckpt```、```res_dir```、```datasets```;
|
||||
- **步骤2**: 在 [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225) 中配置 ```data_dir``` 和 ```dataset_and_weights```;
|
||||
- **步骤3**: 启动服务器:
|
||||
```
|
||||
conda activate unifolm-wma
|
||||
cd unifolm-world-model-action
|
||||
bash scripts/run_real_eval_server.sh
|
||||
```
|
||||
|
||||
### 客户端设置
|
||||
- **步骤1**: 参考 [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md),创建 ```unitree_deploy``` conda 环境,安装所需依赖包,并在真实机器人端启动控制器或服务;
|
||||
- **步骤2**: 打开一个新的终端,从客户端到服务器建立隧道连接:
|
||||
```
|
||||
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
|
||||
```
|
||||
- **步骤3**: 运行 ```unitree_deploy/robot_client.py``` 脚本以启动推理:
|
||||
```
|
||||
cd unitree_deploy
|
||||
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
|
||||
```
|
||||
|
||||
## 📝 代码架构
|
||||
以下是本项目代码结构设计及核心组件说明::
|
||||
```
|
||||
unitree-world-model/
|
||||
├── assets # GIF动图、静态图片和演示视频等媒体素材
|
||||
├── configs # 配置文件
|
||||
│ ├── inference
|
||||
│ └── train
|
||||
├── examples # 示例数据
|
||||
├── external # 外部代码库
|
||||
├── prepare_data # 数据处理
|
||||
├── scripts # 主程序脚本
|
||||
├── src
|
||||
│ ├──unitree_worldmodel # 核心库
|
||||
│ │ ├── data # 数据加载
|
||||
│ │ ├── models # 模型架构
|
||||
│ │ ├── modules # 自定义模块
|
||||
| │ └── utils # 工具函数
|
||||
```
|
||||
|
||||
## 🙏 致谢声明
|
||||
本项目代码基于以下优秀开源项目构建,特此致谢:[DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) 和 [HPT](https://github.com/liruiw/HPT).
|
||||
BIN
assets/pngs/dm_mode.png
Normal file
|
After Width: | Height: | Size: 1.2 MiB |
BIN
assets/pngs/sim_mode.png
Normal file
|
After Width: | Height: | Size: 1.3 MiB |
54
ckpts/.gitattributes
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
||||
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
*.db* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ark* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
*.gguf* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ggml filter=lfs diff=lfs merge=lfs -text
|
||||
*.llamafile* filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
assets/real_cleanup_pencils.gif filter=lfs diff=lfs merge=lfs -text
|
||||
assets/world_model_interaction.gif filter=lfs diff=lfs merge=lfs -text
|
||||
assets/real_dual_stackbox.gif filter=lfs diff=lfs merge=lfs -text
|
||||
assets/real_g1_pack_camera.gif filter=lfs diff=lfs merge=lfs -text
|
||||
assets/real_z1_stackbox.gif filter=lfs diff=lfs merge=lfs -text
|
||||
unifolm_wma_dual.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
439
ckpts/LICENSE
Normal file
@@ -0,0 +1,439 @@
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
|
||||
Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
Public License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
||||
("Public License"). To the extent this Public License may be
|
||||
interpreted as a contract, You are granted the Licensed Rights in
|
||||
consideration of Your acceptance of these terms and conditions, and the
|
||||
Licensor grants You such rights in consideration of benefits the
|
||||
Licensor receives from making the Licensed Material available under
|
||||
these terms and conditions.
|
||||
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. BY-NC-SA Compatible License means a license listed at
|
||||
creativecommons.org/compatiblelicenses, approved by Creative
|
||||
Commons as essentially the equivalent of this Public License.
|
||||
|
||||
d. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
|
||||
e. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
g. License Elements means the license attributes listed in the name
|
||||
of a Creative Commons Public License. The License Elements of this
|
||||
Public License are Attribution, NonCommercial, and ShareAlike.
|
||||
|
||||
h. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
i. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
j. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
k. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
l. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
m. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
n. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. Additional offer from the Licensor -- Adapted Material.
|
||||
Every recipient of Adapted Material from You
|
||||
automatically receives an offer from the Licensor to
|
||||
exercise the Licensed Rights in the Adapted Material
|
||||
under the conditions of the Adapter's License You apply.
|
||||
|
||||
c. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
b. ShareAlike.
|
||||
|
||||
In addition to the conditions in Section 3(a), if You Share
|
||||
Adapted Material You produce, the following conditions also apply.
|
||||
|
||||
1. The Adapter's License You apply must be a Creative Commons
|
||||
license with the same License Elements, this version or
|
||||
later, or a BY-NC-SA Compatible License.
|
||||
|
||||
2. You must include the text of, or the URI or hyperlink to, the
|
||||
Adapter's License You apply. You may satisfy this condition
|
||||
in any reasonable manner based on the medium, means, and
|
||||
context in which You Share Adapted Material.
|
||||
|
||||
3. You may not offer or impose any additional or different terms
|
||||
or conditions on, or apply any Effective Technological
|
||||
Measures to, Adapted Material that restrict exercise of the
|
||||
rights granted under the Adapter's License You apply.
|
||||
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material,
|
||||
including for purposes of Section 3(b); and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
||||
38
ckpts/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
tags:
|
||||
- robotics
|
||||
---
|
||||
|
||||
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
|
||||
<p style="font-size: 1.2em;">
|
||||
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
|
||||
<a href="https://github.com/unitreerobotics/unifolm-world-model-action"><strong>Code</strong></a> |
|
||||
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
|
||||
</p>
|
||||
<div align="center">
|
||||
<div align="justify">
|
||||
<b>UnifoLM-WMA-0</b> is Unitree‘s first open-source world-model–action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world-model capable of understanding the physical interactions between robots and the environments. This world-model provides two key functions: (a) <b>Simulation Engine</b> – operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b> – connects with an action head and, by predicting future interaction processes with the world-model, further optimizes decision-making performance.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## 🦾 Real Robot Deployment
|
||||
| <img src="assets/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|:---:|:---:|
|
||||
| <img src="assets/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|
||||
|
||||
**Note: the top-right window shows the world model’s prediction of future environmental changes.**
|
||||
|
||||
## License
|
||||
The model is released under the CC BY-NC-SA 4.0 license as found in the [LICENSE](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0/blob/main/LICENSE). You are responsible for ensuring that your use of Unitree AI Models complies with all applicable laws.
|
||||
|
||||
## Model Architecture
|
||||

|
||||
|
||||
## Citation
|
||||
```
|
||||
@misc{unifolm-wma-0,
|
||||
author = {Unitree},
|
||||
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
|
||||
year = {2025},
|
||||
}
|
||||
```
|
||||
213
configs/inference/base_model_inference.yaml
Normal file
@@ -0,0 +1,213 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 1
|
||||
n_obs_steps_acting: 1
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: True
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
normalization_config:
|
||||
input_shapes:
|
||||
observation.state: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
input_normalization_modes:
|
||||
observation.state: 'min_max'
|
||||
output_shapes:
|
||||
action: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
output_normalization_modes:
|
||||
action: 'min_max'
|
||||
240
configs/inference/world_model_decision_making.yaml
Normal file
@@ -0,0 +1,240 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: True
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
normalization_config:
|
||||
input_shapes:
|
||||
observation.state: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
input_normalization_modes:
|
||||
observation.state: 'min_max'
|
||||
output_shapes:
|
||||
action: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
output_normalization_modes:
|
||||
action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 6
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_g1_pack_camera: 1.0
|
||||
244
configs/inference/world_model_interaction.yaml
Normal file
@@ -0,0 +1,244 @@
|
||||
model:
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_type: 'empty_seq'
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: False
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 30000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
normalization_config:
|
||||
input_shapes:
|
||||
observation.state: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
input_normalization_modes:
|
||||
observation.state: 'min_max'
|
||||
output_shapes:
|
||||
action: ${model.params.wma_config.params.action_unet_config.params.input_dim}
|
||||
output_normalization_modes:
|
||||
action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 6
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_z1_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox_v2: 0.2
|
||||
unitree_z1_dual_arm_cleanup_pencils: 0.2
|
||||
unitree_g1_pack_camera: 0.2
|
||||
287
configs/train/config.yaml
Normal file
@@ -0,0 +1,287 @@
|
||||
model:
|
||||
pretrained_checkpoint: /path/to/pretrained/checkpoint
|
||||
base_learning_rate: 1.0e-05
|
||||
scale_lr: False
|
||||
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
|
||||
params:
|
||||
rescale_betas_zero_snr: True
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.012
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: video
|
||||
cond_stage_key: instruction
|
||||
cond_stage_trainable: False
|
||||
image_proj_model_trainable: True
|
||||
conditioning_key: hybrid
|
||||
image_size: [40, 64]
|
||||
channels: 4
|
||||
scale_by_std: False
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
uncond_prob: 0.05
|
||||
uncond_type: 'empty_seq'
|
||||
rand_cond_frame: false
|
||||
use_dynamic_rescale: true
|
||||
base_scale: 0.7
|
||||
fps_condition_type: 'fps'
|
||||
perframe_ae: True
|
||||
freeze_embedder: True
|
||||
n_obs_steps_imagen: 2
|
||||
n_obs_steps_acting: 2
|
||||
agent_state_dim: 16
|
||||
agent_action_dim: 16
|
||||
decision_making_only: True
|
||||
|
||||
###################### DP Related
|
||||
input_pertub: 0.1
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 2000
|
||||
num_epochs: 60000
|
||||
gradient_accumulate_every: 1
|
||||
use_scheduler: True
|
||||
dp_use_ema: True
|
||||
|
||||
dp_ema_config:
|
||||
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
|
||||
params:
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
noise_scheduler_config:
|
||||
target: diffusers.DDIMScheduler
|
||||
params:
|
||||
num_train_timesteps: 1000
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
clip_sample: True
|
||||
set_alpha_to_one: True
|
||||
steps_offset: 0
|
||||
prediction_type: epsilon
|
||||
|
||||
dp_optimizer_config:
|
||||
target: torch.optim.AdamW
|
||||
params:
|
||||
lr: 1.0e-4
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
wma_config:
|
||||
target: unifolm_wma.modules.networks.wma_model.WMAModel
|
||||
params:
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions:
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
dropout: 0.1
|
||||
num_head_channels: 64
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_linear: true
|
||||
use_checkpoint: True
|
||||
temporal_conv: True
|
||||
temporal_attention: True
|
||||
temporal_selfatt_only: True
|
||||
use_relative_position: False
|
||||
use_causal_attention: False
|
||||
temporal_length: 16
|
||||
addition_attention: True
|
||||
image_cross_attention: True
|
||||
default_fs: 10
|
||||
fs_condition: True
|
||||
cross_attention_scale_learnable: False
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
num_stem_token: 16
|
||||
base_model_gen_only: False
|
||||
|
||||
unet_head_config:
|
||||
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
|
||||
params:
|
||||
input_dim: ${model.params.agent_action_dim}
|
||||
n_obs_steps: ${model.params.n_obs_steps_acting}
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [256, 512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
num_head_channels: ${model.params.wma_config.params.num_head_channels}
|
||||
horizon: ${model.params.wma_config.params.temporal_length}
|
||||
use_linear_attn: ${model.params.wma_config.params.use_linear}
|
||||
use_linear_act_proj: True
|
||||
act_proj_dim: 32
|
||||
cond_cross_attention: False
|
||||
context_dims: []
|
||||
image_size: ${model.params.image_size}
|
||||
imagen_cond_gradient: True
|
||||
last_frame_only: False
|
||||
use_imagen_mid_only: False
|
||||
use_z_only: False
|
||||
|
||||
obs_encoder_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
params:
|
||||
rgb_model_config:
|
||||
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
|
||||
params:
|
||||
name: resnet18
|
||||
weights: null
|
||||
resize_shape: null
|
||||
crop_shape: null
|
||||
random_crop: False
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
use_spatial_softmax: True
|
||||
spatial_softmax_kp: 128
|
||||
|
||||
###################### Action Tokenization
|
||||
stem_process_config:
|
||||
target: unifolm_wma.modules.encoders.condition.SATokenProjector
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 1
|
||||
dim_head: 64
|
||||
heads: 16
|
||||
num_queries: ${model.params.wma_config.params.num_stem_token}
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
chunk_size: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
first_stage_config:
|
||||
target: unifolm_wma.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
img_cond_stage_config:
|
||||
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
|
||||
params:
|
||||
freeze: true
|
||||
|
||||
image_proj_stage_config:
|
||||
target: unifolm_wma.modules.encoders.resampler.Resampler
|
||||
params:
|
||||
dim: 1024
|
||||
depth: 4
|
||||
dim_head: 64
|
||||
heads: 12
|
||||
num_queries: 16
|
||||
embedding_dim: 1280
|
||||
output_dim: 1024
|
||||
ff_mult: 4
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
|
||||
normalization_config:
|
||||
input_shapes:
|
||||
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
|
||||
input_normalization_modes:
|
||||
observation.state: 'min_max'
|
||||
output_shapes:
|
||||
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
|
||||
output_normalization_modes:
|
||||
action: 'min_max'
|
||||
|
||||
data:
|
||||
target: unifolm_wma.utils.data.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 8
|
||||
num_workers: 12
|
||||
wrap: False
|
||||
train:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/path/to/training/dataset/directory'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
resolution: [320, 512]
|
||||
spatial_transform: resize_center_crop
|
||||
crop_resolution: [320, 512]
|
||||
random_fs: False
|
||||
cond_robot_label_prob: 0.0
|
||||
normalization_mode: 'min_max'
|
||||
individual_normalization: True
|
||||
n_obs_steps: ${model.params.n_obs_steps_imagen}
|
||||
max_action_dim: ${model.params.agent_action_dim}
|
||||
max_state_dim: ${model.params.agent_state_dim}
|
||||
dataset_and_weights:
|
||||
unitree_z1_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox: 0.2
|
||||
unitree_z1_dual_arm_stackbox_v2: 0.2
|
||||
unitree_z1_dual_arm_cleanup_pencils: 0.2
|
||||
unitree_g1_pack_camera: 0.2
|
||||
|
||||
lightning:
|
||||
precision: 16
|
||||
trainer:
|
||||
benchmark: True
|
||||
accumulate_grad_batches: 2
|
||||
max_steps: 300000
|
||||
log_every_n_steps: 50
|
||||
val_check_interval: 1.0
|
||||
gradient_clip_algorithm: 'norm'
|
||||
gradient_clip_val: 0.5
|
||||
enable_model_summary: False
|
||||
callbacks:
|
||||
model_checkpoint:
|
||||
target: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
params:
|
||||
every_n_train_steps: 1000
|
||||
filename: "{epoch}-{step}"
|
||||
save_weights_only: True
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
target: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
params:
|
||||
filename: '{epoch}-{step}'
|
||||
save_weights_only: True
|
||||
every_n_train_steps: 10000
|
||||
batch_logger:
|
||||
target: unifolm_wma.utils.callbacks.ImageLogger
|
||||
params:
|
||||
batch_frequency: 20000
|
||||
to_local: False
|
||||
max_images: 8
|
||||
log_images_kwargs:
|
||||
ddim_steps: 16
|
||||
unconditional_guidance_scale: 1.0
|
||||
timestep_spacing: uniform_trailing
|
||||
guidance_rescale: 0.7
|
||||
BIN
examples/base_model_prompts/0.png
Normal file
|
After Width: | Height: | Size: 162 KiB |
BIN
examples/base_model_prompts/1.png
Normal file
|
After Width: | Height: | Size: 258 KiB |
BIN
examples/base_model_prompts/10.png
Normal file
|
After Width: | Height: | Size: 39 KiB |
BIN
examples/base_model_prompts/11.png
Normal file
|
After Width: | Height: | Size: 256 KiB |
BIN
examples/base_model_prompts/12.png
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
examples/base_model_prompts/13.png
Normal file
|
After Width: | Height: | Size: 85 KiB |
BIN
examples/base_model_prompts/14.png
Normal file
|
After Width: | Height: | Size: 82 KiB |
BIN
examples/base_model_prompts/15.png
Normal file
|
After Width: | Height: | Size: 257 KiB |
BIN
examples/base_model_prompts/2.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
examples/base_model_prompts/3.png
Normal file
|
After Width: | Height: | Size: 45 KiB |
BIN
examples/base_model_prompts/4.png
Normal file
|
After Width: | Height: | Size: 41 KiB |
BIN
examples/base_model_prompts/5.png
Normal file
|
After Width: | Height: | Size: 45 KiB |
BIN
examples/base_model_prompts/6.png
Normal file
|
After Width: | Height: | Size: 84 KiB |
BIN
examples/base_model_prompts/7.png
Normal file
|
After Width: | Height: | Size: 44 KiB |
BIN
examples/base_model_prompts/8.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
examples/base_model_prompts/9.png
Normal file
|
After Width: | Height: | Size: 59 KiB |
17
examples/base_model_prompts/prompts.csv
Normal file
@@ -0,0 +1,17 @@
|
||||
videoid,instruction,fps,start_idx,fs,num_gen
|
||||
0,wash the pan,16.0,152.0,2.0,6.0
|
||||
1,pick up the blue cup and put it into the brown cup. ,5.0,0.0,2.0,4.0
|
||||
2,close top drawer,3.0,0.0,2.0,1.0
|
||||
3,Close the laptop.,10.0,0.0,2.0,4.0
|
||||
4,destack cube,5.0,0.0,2.0,3.0
|
||||
5,arrange plate and fork,20.0,40.0,2.0,4.0
|
||||
6,Place the lid on the teapot,15.0,30.0,1.0,4.0
|
||||
7,Pick up the green object and insert it.,10.0,0.0,2.0,4.0
|
||||
8,place the burger meat in the oven,10.0,0.0,2.0,2.0
|
||||
9,make a cup of coffee with the keurig machine,10.0,0.0,2.0,4.0
|
||||
10,assemble one_leg,10.0,0.0,2.0,7.0
|
||||
11,get the cloth and wipe up the spill under the wine glass,8.0,669.0,2.0,3.0
|
||||
12,palce dishes in the dish rack,10.0,0.0,2.0,4.0
|
||||
13,move redbull can near green can,3.0,3.0,2.0,1.0
|
||||
14,open the drawer,5.0,5.0,1.0,2.0
|
||||
15,sweep the green cloth to the left side of the table,5.0,0.0,2.0,3.0
|
||||
|
|
After Width: | Height: | Size: 153 KiB |
|
After Width: | Height: | Size: 134 KiB |
|
After Width: | Height: | Size: 286 KiB |
|
After Width: | Height: | Size: 161 KiB |
|
After Width: | Height: | Size: 101 KiB |
BIN
examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5
Executable file
@@ -0,0 +1,2 @@
|
||||
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
|
||||
0,x,x,unitree_g1_pack_camera,Pack black camera into box.,x,x,x,Unitree G1 Robot with Gripper,30
|
||||
|
@@ -0,0 +1,2 @@
|
||||
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
|
||||
0,x,x,unitree_z1_dual_arm_cleanup_pencils,clean up eraser and pencils,x,x,x,Unitree Z1 Robot Dual-Arm,30
|
||||
|
@@ -0,0 +1,2 @@
|
||||
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
|
||||
0,x,x,unitree_z1_dual_arm_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Dual-Arm,30
|
||||
|
@@ -0,0 +1,2 @@
|
||||
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
|
||||
0,x,x,unitree_z1_dual_arm_stackbox_v2,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top",x,x,x,Unitree Z1 Robot Dual-Arm,30
|
||||
|
2
examples/world_model_interaction_prompts/unitree_z1_stackbox.csv
Executable file
@@ -0,0 +1,2 @@
|
||||
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
|
||||
0,x,x,unitree_z1_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Arm,30
|
||||
|
1
external/dlimp
vendored
Submodule
1204
model_architecture_analysis.md
Normal file
199
prepare_data/prepare_training_data.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import h5py
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import torch
|
||||
import subprocess
|
||||
|
||||
from pathlib import Path
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def is_av1(file_path):
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"ffprobe", "-v", "error", "-select_streams", "v:0",
|
||||
"-show_entries", "stream=codec_name", "-of", "csv=p=0",
|
||||
str(file_path)
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True)
|
||||
return result.stdout.strip() == "av1"
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def convert_to_h264(input_path, output_path):
|
||||
subprocess.run([
|
||||
"ffmpeg", "-i",
|
||||
str(input_path), "-c:v", "libx264", "-preset", "slow", "-crf", "23",
|
||||
"-c:a", "copy",
|
||||
str(output_path)
|
||||
],
|
||||
check=True)
|
||||
|
||||
|
||||
def main(args):
|
||||
source_dir = Path(args.source_dir)
|
||||
source_data_dir = source_dir / args.dataset_name / "data" / "chunk-000"
|
||||
source_meta_dir = source_dir / args.dataset_name / "meta"
|
||||
source_videos_dir = source_dir / args.dataset_name / "videos" / "chunk-000"
|
||||
|
||||
target_dir = Path(args.target_dir)
|
||||
target_videos_dir = target_dir / "videos" / args.dataset_name
|
||||
target_transitions_dir = target_dir / "transitions" / args.dataset_name
|
||||
target_meta_dir = target_dir / "transitions" / args.dataset_name / "meta_data"
|
||||
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_transitions_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_meta_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
csv_file = target_dir / f"{args.dataset_name}.csv"
|
||||
COLUMNS = [
|
||||
'videoid', 'contentUrl', 'duration', 'data_dir', 'instruction',
|
||||
'dynamic_confidence', 'dynamic_wording', 'dynamic_source_category',
|
||||
'embodiment'
|
||||
]
|
||||
df = pd.DataFrame(columns=COLUMNS)
|
||||
|
||||
# Load info.json from source dir
|
||||
info_json_path = source_meta_dir / "info.json"
|
||||
with open(str(info_json_path), "r") as f:
|
||||
info = json.load(f)
|
||||
total_episodes = info['total_episodes']
|
||||
|
||||
# Load task.jsonl to get lanugage ins
|
||||
tasks_jsonl_path = source_meta_dir / "tasks.jsonl"
|
||||
with open(str(tasks_jsonl_path), "r") as f:
|
||||
tasks = [json.loads(line) for line in f]
|
||||
instruction = tasks[0]['task']
|
||||
|
||||
source_video_views = [d for d in source_videos_dir.iterdir()]
|
||||
for v_idx, source_view_dir in enumerate(source_video_views):
|
||||
|
||||
view_name = source_view_dir.name
|
||||
target_videos_view_dir = target_videos_dir / view_name
|
||||
target_videos_view_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if v_idx == 0:
|
||||
all_actions = []
|
||||
all_states = []
|
||||
|
||||
for idx in tqdm(range(total_episodes)):
|
||||
# Copy source video to target vidoe dir
|
||||
source_video = source_view_dir / f"episode_{idx:06d}.mp4"
|
||||
if is_av1(source_video):
|
||||
output_video = str(target_videos_view_dir / f"{idx}.mp4")
|
||||
print(f"Converting episode_{idx:06d}.mp4 to H.264...")
|
||||
convert_to_h264(source_video, output_video)
|
||||
else:
|
||||
print(f"Skipping episode_{idx:06d}.mp4: not AV1 encoded.")
|
||||
|
||||
# Load parquet file
|
||||
episode_parquet_file = source_data_dir / f"episode_{idx:06d}.parquet"
|
||||
episode_data = pd.read_parquet(episode_parquet_file)
|
||||
actions = torch.tensor(episode_data['action'].tolist())
|
||||
states = torch.tensor(episode_data['observation.state'].tolist())
|
||||
|
||||
# Save action and state into a h5 file
|
||||
if v_idx == 0:
|
||||
target_h5_file = target_transitions_dir / f"{idx}.h5"
|
||||
with h5py.File(str(target_h5_file), 'w') as h5f:
|
||||
h5f.create_dataset('observation.state', data=states)
|
||||
h5f.create_dataset('action', data=actions)
|
||||
h5f.attrs['action_type'] = 'joint position'
|
||||
h5f.attrs['state_type'] = 'joint position'
|
||||
h5f.attrs['robot_type'] = args.robot_name
|
||||
|
||||
# Updata df
|
||||
df = pd.concat([
|
||||
df,
|
||||
pd.DataFrame([{
|
||||
'videoid': idx,
|
||||
'contentUrl': 'x',
|
||||
'duration': 'x',
|
||||
'data_dir': args.dataset_name + f"/{view_name}",
|
||||
'instruction': instruction,
|
||||
'dynamic_confidence': 'x',
|
||||
'dynamic_wording': 'x',
|
||||
'dynamic_source_category': 'x',
|
||||
'embodiment': args.robot_name
|
||||
}])
|
||||
],
|
||||
ignore_index=True)
|
||||
|
||||
# Collect action and state
|
||||
if v_idx == 0:
|
||||
all_actions.append(actions)
|
||||
all_states.append(states)
|
||||
|
||||
# Create satas.safetensors
|
||||
actions = torch.cat(all_actions, dim=0)
|
||||
states = torch.cat(all_states, dim=0)
|
||||
|
||||
stats = {'action': {}, 'observation.state': {}}
|
||||
stats['action']['max'] = actions.max(dim=0).values
|
||||
stats['action']['min'] = actions.min(dim=0).values
|
||||
stats['action']['mean'] = actions.mean(dim=0)
|
||||
stats['action']['std'] = actions.std(dim=0)
|
||||
|
||||
stats['observation.state']['max'] = states.max(dim=0).values
|
||||
stats['observation.state']['min'] = states.min(dim=0).values
|
||||
stats['observation.state']['mean'] = states.mean(dim=0)
|
||||
stats['observation.state']['std'] = states.std(dim=0)
|
||||
|
||||
flattened_stats = flatten_dict(stats)
|
||||
target_stats_file = target_meta_dir / "stats.safetensors"
|
||||
save_file(flattened_stats, target_stats_file)
|
||||
|
||||
df.to_csv(csv_file, index=False)
|
||||
print(f">>> Finished create {args.dataset_name} dataset ...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--source_dir',
|
||||
action='store',
|
||||
type=str,
|
||||
help='The dataset dir under lerobot 2.0 data format.',
|
||||
required=True)
|
||||
parser.add_argument('--target_dir',
|
||||
action='store',
|
||||
type=str,
|
||||
default='./data',
|
||||
help='The target dir to save new formatted dataset.')
|
||||
parser.add_argument('--dataset_name',
|
||||
action='store',
|
||||
type=str,
|
||||
help='dataset name',
|
||||
required=True)
|
||||
parser.add_argument('--robot_name',
|
||||
action='store',
|
||||
type=str,
|
||||
help='robot name',
|
||||
required=True)
|
||||
main(parser.parse_args())
|
||||
53
pyproject.toml
Executable file
@@ -0,0 +1,53 @@
|
||||
[project]
|
||||
name = "unifolm_wma"
|
||||
version = "0.0.1"
|
||||
description = "UnifoLM-WMA-0"
|
||||
license = { text = "BSD-3-Clause" }
|
||||
authors = [
|
||||
{name="Unitree Embodied AI R&D Team", email="rd_xyc@unitree.com" }
|
||||
]
|
||||
requires-python = "==3.10.18"
|
||||
dependencies = [
|
||||
"decord==0.6.0",
|
||||
"einops==0.8.0",
|
||||
"imageio==2.35.1",
|
||||
"numpy==1.24.2",
|
||||
"omegaconf==2.3.0",
|
||||
"opencv-python==4.10.0.84",
|
||||
"pandas==2.0.0",
|
||||
"pillow==9.5.0",
|
||||
"pytorch-lightning==1.9.3",
|
||||
"pyyaml==6.0",
|
||||
"setuptools==65.6.3",
|
||||
"torch==2.3.1",
|
||||
"torchvision==0.18.1",
|
||||
"tqdm==4.66.5",
|
||||
"transformers==4.40.1",
|
||||
"moviepy==1.0.3",
|
||||
"av==12.3.0",
|
||||
"xformers==0.0.27",
|
||||
"gradio==4.39.0",
|
||||
"timm==0.9.10",
|
||||
"scikit-learn==1.5.1",
|
||||
"open-clip-torch==2.22.0",
|
||||
"kornia==0.7.3",
|
||||
"diffusers==0.30.2",
|
||||
"termcolor==2.4.0",
|
||||
"draccus==0.11.5",
|
||||
"accelerate==1.7.0",
|
||||
"tensorflow-metadata==1.16.1",
|
||||
"protobuf==3.20.3",
|
||||
"datasets==3.6.0",
|
||||
"tensorflow-graphics==2021.12.3",
|
||||
"fairscale==0.4.13"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=65.6.3", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = { "" = "src" }
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
114
run_all_cases.sh
Executable file
@@ -0,0 +1,114 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 自动执行所有场景的所有case
|
||||
# 总共5个场景,每个场景4个case,共20个case
|
||||
# 设置环境变量(离线模式)
|
||||
export HF_HUB_OFFLINE=1
|
||||
export TRANSFORMERS_OFFLINE=1
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 定义所有场景
|
||||
SCENARIOS=(
|
||||
"unitree_g1_pack_camera"
|
||||
"unitree_z1_dual_arm_cleanup_pencils"
|
||||
"unitree_z1_dual_arm_stackbox"
|
||||
"unitree_z1_dual_arm_stackbox_v2"
|
||||
"unitree_z1_stackbox"
|
||||
)
|
||||
|
||||
# 定义case数量
|
||||
CASES=(1 2 3 4)
|
||||
|
||||
# 记录开始时间
|
||||
START_TIME=$(date +%s)
|
||||
LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "${BLUE}开始执行所有场景的case${NC}"
|
||||
echo -e "${BLUE}总共: ${#SCENARIOS[@]} 个场景 x ${#CASES[@]} 个case = $((${#SCENARIOS[@]} * ${#CASES[@]})) 个任务${NC}"
|
||||
echo -e "${BLUE}日志文件: ${LOG_FILE}${NC}"
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo ""
|
||||
|
||||
# 初始化计数器
|
||||
TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]}))
|
||||
CURRENT_CASE=0
|
||||
SUCCESS_COUNT=0
|
||||
FAIL_COUNT=0
|
||||
|
||||
# 记录失败的case
|
||||
declare -a FAILED_CASES
|
||||
|
||||
# 遍历所有场景
|
||||
for scenario in "${SCENARIOS[@]}"; do
|
||||
echo -e "${YELLOW}>>> 场景: ${scenario}${NC}"
|
||||
|
||||
# 遍历所有case
|
||||
for case_num in "${CASES[@]}"; do
|
||||
CURRENT_CASE=$((CURRENT_CASE + 1))
|
||||
case_dir="${scenario}/case${case_num}"
|
||||
script_path="${case_dir}/run_world_model_interaction.sh"
|
||||
|
||||
echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] 执行: ${case_dir}${NC}"
|
||||
|
||||
# 检查脚本是否存在
|
||||
if [ ! -f "${script_path}" ]; then
|
||||
echo -e "${RED}错误: 脚本不存在 ${script_path}${NC}"
|
||||
FAIL_COUNT=$((FAIL_COUNT + 1))
|
||||
FAILED_CASES+=("${case_dir} (脚本不存在)")
|
||||
continue
|
||||
fi
|
||||
|
||||
# 执行脚本
|
||||
echo "开始时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
||||
|
||||
if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then
|
||||
echo -e "${GREEN}✓ 成功: ${case_dir}${NC}"
|
||||
SUCCESS_COUNT=$((SUCCESS_COUNT + 1))
|
||||
else
|
||||
echo -e "${RED}✗ 失败: ${case_dir}${NC}"
|
||||
FAIL_COUNT=$((FAIL_COUNT + 1))
|
||||
FAILED_CASES+=("${case_dir}")
|
||||
fi
|
||||
|
||||
echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
||||
echo ""
|
||||
done
|
||||
|
||||
echo ""
|
||||
done
|
||||
|
||||
# 计算总耗时
|
||||
END_TIME=$(date +%s)
|
||||
DURATION=$((END_TIME - START_TIME))
|
||||
HOURS=$((DURATION / 3600))
|
||||
MINUTES=$(((DURATION % 3600) / 60))
|
||||
SECONDS=$((DURATION % 60))
|
||||
|
||||
# 输出总结
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "${BLUE}执行完成!${NC}"
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "总任务数: ${TOTAL_CASES}"
|
||||
echo -e "${GREEN}成功: ${SUCCESS_COUNT}${NC}"
|
||||
echo -e "${RED}失败: ${FAIL_COUNT}${NC}"
|
||||
echo -e "总耗时: ${HOURS}小时 ${MINUTES}分钟 ${SECONDS}秒"
|
||||
echo -e "详细日志: ${LOG_FILE}"
|
||||
echo ""
|
||||
|
||||
# 如果有失败的case,列出来
|
||||
if [ ${FAIL_COUNT} -gt 0 ]; then
|
||||
echo -e "${RED}失败的case列表:${NC}"
|
||||
for failed_case in "${FAILED_CASES[@]}"; do
|
||||
echo -e "${RED} - ${failed_case}${NC}"
|
||||
done
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
541
scripts/evaluation/base_model_inference.py
Normal file
@@ -0,0 +1,541 @@
|
||||
import argparse, os, glob
|
||||
import datetime, time
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import random
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from PIL import Image
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
"""
|
||||
Get list of files in `data_dir` with extensions in `postfixes`.
|
||||
|
||||
Args:
|
||||
data_dir (str): Directory path.
|
||||
postfixes (list[str]): List of file extensions (e.g., ['csv', 'jpg']).
|
||||
|
||||
Returns:
|
||||
list[str]: Sorted list of matched file paths.
|
||||
"""
|
||||
patterns = [
|
||||
os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes
|
||||
]
|
||||
file_list = []
|
||||
for pattern in patterns:
|
||||
file_list.extend(glob.glob(pattern))
|
||||
file_list.sort()
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model_checkpoint(model: torch.nn.Module,
|
||||
ckpt: str) -> torch.nn.Module:
|
||||
"""
|
||||
Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to load weights into.
|
||||
ckpt (str): Path to the checkpoint file.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: Model with weights loaded.
|
||||
"""
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
if "state_dict" in list(state_dict.keys()):
|
||||
state_dict = state_dict["state_dict"]
|
||||
try:
|
||||
loaded = model.load_state_dict(state_dict, strict=False)
|
||||
print("Missing keys:")
|
||||
for k in loaded.missing_keys:
|
||||
print(f" {k}")
|
||||
print("Unexpected keys:")
|
||||
for k in loaded.unexpected_keys:
|
||||
print(f" {k}")
|
||||
|
||||
except:
|
||||
# Rename the keys for 256x256 model
|
||||
new_pl_sd = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_pl_sd[k] = v
|
||||
|
||||
for k in list(new_pl_sd.keys()):
|
||||
if "framestride_embed" in k:
|
||||
new_key = k.replace("framestride_embed", "fps_embedding")
|
||||
new_pl_sd[new_key] = new_pl_sd[k]
|
||||
del new_pl_sd[k]
|
||||
model.load_state_dict(new_pl_sd, strict=False)
|
||||
else:
|
||||
new_pl_sd = OrderedDict()
|
||||
for key in state_dict['module'].keys():
|
||||
new_pl_sd[key[16:]] = state_dict['module'][key]
|
||||
model.load_state_dict(new_pl_sd)
|
||||
print('>>> model checkpoint loaded.')
|
||||
return model
|
||||
|
||||
|
||||
def load_prompts(prompt_file: str) -> list[str]:
|
||||
"""
|
||||
Load prompts from a text file, one per line.
|
||||
|
||||
Args:
|
||||
prompt_file (str): Path to the prompt file.
|
||||
|
||||
Returns:
|
||||
list[str]: List of prompt strings.
|
||||
"""
|
||||
f = open(prompt_file, 'r')
|
||||
prompt_list = []
|
||||
for idx, line in enumerate(f.readlines()):
|
||||
l = line.strip()
|
||||
if len(l) != 0:
|
||||
prompt_list.append(l)
|
||||
f.close()
|
||||
return prompt_list
|
||||
|
||||
|
||||
def load_data_prompts(
|
||||
data_dir: str,
|
||||
savedir: str,
|
||||
video_size: tuple[int, int] = (256, 256),
|
||||
video_frames: int = 16
|
||||
) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
|
||||
list[int]]:
|
||||
"""
|
||||
Load image prompts, process them into video format, and retrieve metadata.
|
||||
|
||||
Args:
|
||||
data_dir (str): Directory containing images and CSV file.
|
||||
savedir (str): Output directory to check if inference was already done.
|
||||
video_size (tuple[int, int], optional): Target size of video frames.
|
||||
video_frames (int, optional): Number of frames in each video.
|
||||
|
||||
Returns:
|
||||
tuple: (filenames, video tensors, prompts, fps values, fs values, num_generations)
|
||||
"""
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(min(video_size)),
|
||||
transforms.CenterCrop(video_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
# Load prompt csv
|
||||
prompt_file = get_filelist(data_dir, ['csv'])
|
||||
assert len(prompt_file) > 0, "Error: found NO image prompt file!"
|
||||
|
||||
# Load image prompts
|
||||
file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
|
||||
data_list = []
|
||||
filename_list = []
|
||||
prompt_list = []
|
||||
fps_list = []
|
||||
fs_list = []
|
||||
num_gen_list = []
|
||||
prompt_csv = pd.read_csv(prompt_file[0])
|
||||
n_samples = len(file_list)
|
||||
|
||||
for idx in range(n_samples):
|
||||
image = Image.open(file_list[idx]).convert('RGB')
|
||||
image_tensor = transform(image).unsqueeze(1)
|
||||
frame_tensor = repeat(image_tensor,
|
||||
'c t h w -> c (repeat t) h w',
|
||||
repeat=video_frames)
|
||||
_, filename = os.path.split(file_list[idx])
|
||||
|
||||
if not is_inferenced(savedir, filename):
|
||||
video_id = filename[:-4]
|
||||
prompt_csv['videoid'] = prompt_csv['videoid'].map(str)
|
||||
if not (prompt_csv['videoid'] == video_id).any():
|
||||
continue
|
||||
data_list.append(frame_tensor)
|
||||
filename_list.append(filename)
|
||||
ins = prompt_csv[prompt_csv['videoid'] ==
|
||||
video_id]['instruction'].values[0]
|
||||
prompt_list.append(ins)
|
||||
fps = prompt_csv[prompt_csv['videoid'] ==
|
||||
video_id]['fps'].values[0]
|
||||
fps_list.append(fps)
|
||||
fs = prompt_csv[prompt_csv['videoid'] == video_id]['fs'].values[0]
|
||||
fs_list.append(fs)
|
||||
num_gen = prompt_csv[prompt_csv['videoid'] ==
|
||||
video_id]['num_gen'].values[0]
|
||||
num_gen_list.append(int(num_gen))
|
||||
|
||||
return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
|
||||
|
||||
|
||||
def is_inferenced(save_dir: str, filename: str) -> bool:
|
||||
"""
|
||||
Check if a result video already exists.
|
||||
|
||||
Args:
|
||||
save_dir (str): Directory where results are saved.
|
||||
filename (str): Base filename to check.
|
||||
|
||||
Returns:
|
||||
bool: True if file exists, else False.
|
||||
"""
|
||||
video_file = os.path.join(save_dir, f"{filename[:-4]}.mp4")
|
||||
return os.path.exists(video_file)
|
||||
|
||||
|
||||
def save_results_seperate(prompt: str | list[str],
|
||||
samples: torch.Tensor,
|
||||
filename: str,
|
||||
fakedir: str,
|
||||
fps: int = 8) -> None:
|
||||
"""
|
||||
Save generated video samples as .mp4 files.
|
||||
|
||||
Args:
|
||||
prompt (str | list[str]): The prompt text.
|
||||
samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W].
|
||||
filename (str): Output filename.
|
||||
fakedir (str): Directory to save output videos.
|
||||
fps (int, optional): Frames per second.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
|
||||
# Save video
|
||||
videos = [samples]
|
||||
savedirs = [fakedir]
|
||||
for idx, video in enumerate(videos):
|
||||
if video is None:
|
||||
continue
|
||||
video = video.detach().cpu()
|
||||
video = torch.clamp(video.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
for i in range(n):
|
||||
grid = video[i, ...]
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0)
|
||||
path = os.path.join(savedirs[idx], f'{filename.split(".")[0]}.mp4')
|
||||
torchvision.io.write_video(path,
|
||||
grid,
|
||||
fps=fps,
|
||||
video_codec='h264',
|
||||
options={'crf': '0'})
|
||||
|
||||
|
||||
def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode videos to latent space.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model with encode_first_stage function.
|
||||
videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Latent representation of shape [B, C, T, H, W].
|
||||
"""
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
|
||||
|
||||
def image_guided_synthesis(model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
videos: torch.Tensor,
|
||||
noise_shape: list[int],
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 1.0,
|
||||
unconditional_guidance_scale: float = 1.0,
|
||||
fs: int | None = None,
|
||||
text_input: bool = False,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Diffusion model.
|
||||
prompts (list[str]): Text prompts.
|
||||
videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W].
|
||||
noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
|
||||
ddim_steps (int, optional): Number of DDIM steps.
|
||||
ddim_eta (float, optional): Eta value for DDIM.
|
||||
unconditional_guidance_scale (float, optional): Guidance scale.
|
||||
fs (int | None, optional): FPS input for sampler.
|
||||
text_input (bool, optional): If True, use text guidance.
|
||||
timestep_spacing (str, optional): Timestep schedule spacing.
|
||||
guidance_rescale (float, optional): Rescale guidance effect.
|
||||
**kwargs: Additional sampler args.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
|
||||
"""
|
||||
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
batch_size = noise_shape[0]
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
if not text_input:
|
||||
prompts = [""] * batch_size
|
||||
|
||||
b, c, t, h, w = videos.shape
|
||||
img = videos[:, :, 0]
|
||||
img_emb = model.embedder(img)
|
||||
img_emb = model.image_proj_model(img_emb)
|
||||
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
|
||||
cond_emb = model.get_learned_conditioning(prompts)
|
||||
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, videos)
|
||||
img_cat_cond = z[:, :, :1, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=z.shape[2])
|
||||
cond["c_concat"] = [img_cat_cond]
|
||||
|
||||
uc = None
|
||||
cond_mask = None
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
|
||||
batch_variants = []
|
||||
if ddim_sampler is not None:
|
||||
samples, _, _, _ = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
conditioning=cond,
|
||||
eta=ddim_eta,
|
||||
mask=cond_mask,
|
||||
x0=None,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants.append(batch_images)
|
||||
|
||||
batch_variants = torch.stack(batch_variants)
|
||||
return batch_variants.permute(1, 0, 2, 3, 4, 5)
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
"""
|
||||
Run inference pipeline on prompts and image inputs.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Parsed command-line arguments.
|
||||
gpu_num (int): Number of GPUs.
|
||||
gpu_no (int): Index of the current GPU.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Load config
|
||||
config = OmegaConf.load(args.config)
|
||||
# Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model = model.cuda(gpu_no)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
|
||||
# Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
== 0), "Error: image size [h,w] should be multiples of 16!"
|
||||
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
|
||||
|
||||
# Get latent noise shape
|
||||
h, w = args.height // 8, args.width // 8
|
||||
channels = model.model.diffusion_model.out_channels
|
||||
n_frames = args.video_length
|
||||
print(f'>>> Generate {n_frames} frames under each generation ...')
|
||||
noise_shape = [args.bs, channels, n_frames, h, w]
|
||||
|
||||
fakedir = os.path.join(args.savedir, "samples")
|
||||
os.makedirs(fakedir, exist_ok=True)
|
||||
|
||||
# Prompt file setting
|
||||
assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
|
||||
filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
|
||||
args.prompt_dir,
|
||||
args.savedir,
|
||||
video_size=(args.height, args.width),
|
||||
video_frames=n_frames)
|
||||
|
||||
num_samples = len(prompt_list)
|
||||
samples_split = num_samples // gpu_num
|
||||
print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
|
||||
(gpu_no, samples_split, num_samples))
|
||||
|
||||
indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
|
||||
fps_list_rank = [fps_list[i] for i in indices]
|
||||
fs_list_rank = [fs_list[i] for i in indices]
|
||||
prompt_list_rank = [prompt_list[i] for i in indices]
|
||||
data_list_rank = [data_list[i] for i in indices]
|
||||
filename_list_rank = [filename_list[i] for i in indices]
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
# Create a new result csv
|
||||
for idx, indice in enumerate(
|
||||
tqdm(range(0, len(prompt_list_rank), args.bs),
|
||||
desc=f'Sample batch')):
|
||||
fps = fps_list_rank[indice:indice + args.bs]
|
||||
fs = fs_list_rank[indice:indice + args.bs]
|
||||
prompts = prompt_list_rank[indice:indice + args.bs]
|
||||
num_gen = num_gen_list[indice:indice + args.bs]
|
||||
videos = data_list_rank[indice:indice + args.bs]
|
||||
filenames = filename_list_rank[indice:indice + args.bs]
|
||||
if isinstance(videos, list):
|
||||
videos = torch.stack(videos, dim=0).to("cuda")
|
||||
else:
|
||||
videos = videos.unsqueeze(0).to("cuda")
|
||||
|
||||
results = []
|
||||
print(
|
||||
f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
|
||||
)
|
||||
for _ in range(num_gen[0]):
|
||||
batch_samples = image_guided_synthesis(
|
||||
model, prompts, videos, noise_shape, args.ddim_steps,
|
||||
args.ddim_eta, args.unconditional_guidance_scale,
|
||||
fps[0] // fs[0], args.text_input, args.timestep_spacing,
|
||||
args.guidance_rescale)
|
||||
results.extend(batch_samples)
|
||||
videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=batch_samples[0].shape[2])
|
||||
batch_samples = [torch.concat(results, axis=2)]
|
||||
|
||||
# Save each example individually
|
||||
for nn, samples in enumerate(batch_samples):
|
||||
prompt = prompts[nn]
|
||||
filename = filenames[nn]
|
||||
save_results_seperate(prompt,
|
||||
samples,
|
||||
filename,
|
||||
fakedir,
|
||||
fps=8)
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Create and return the argument parser.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Parser for command-line arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--savedir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the results.")
|
||||
parser.add_argument("--ckpt_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the model checkpoint.")
|
||||
parser.add_argument("--config",
|
||||
type=str,
|
||||
help="Path to the YAML configuration file.")
|
||||
parser.add_argument(
|
||||
"--prompt_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory containing videos and corresponding prompts.")
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of DDIM steps. If non-positive, DDPM is used instead.")
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
|
||||
parser.add_argument("--bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size for inference. Must be 1.")
|
||||
parser.add_argument("--height",
|
||||
type=int,
|
||||
default=320,
|
||||
help="Height of the generated images in pixels.")
|
||||
parser.add_argument("--width",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Width of the generated images in pixels.")
|
||||
parser.add_argument(
|
||||
"--unconditional_guidance_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale for classifier-free guidance during sampling.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=123,
|
||||
help="Random seed for reproducibility.")
|
||||
parser.add_argument("--video_length",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of frames in the generated video.")
|
||||
parser.add_argument(
|
||||
"--text_input",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"Whether to provide a text prompt as input to the image-to-video model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_spacing",
|
||||
type=str,
|
||||
default="uniform",
|
||||
help=
|
||||
"Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_rescale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help=
|
||||
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--perframe_ae",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
if seed < 0:
|
||||
seed = random.randint(0, 2**31)
|
||||
seed_everything(seed)
|
||||
rank, gpu_num = 0, 1
|
||||
run_inference(args, gpu_num, rank)
|
||||
77
scripts/evaluation/eval_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import warnings
|
||||
import torchvision
|
||||
import sys
|
||||
import pyarrow as pa
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
|
||||
from datasets.features.features import register_feature
|
||||
from torch.utils.tensorboard.writer import SummaryWriter
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
"""
|
||||
Provides a type for a dataset containing video frames.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
|
||||
features = {"image": VideoFrame()}
|
||||
Dataset.from_dict(data_dict, features=Features(features))
|
||||
```
|
||||
"""
|
||||
|
||||
pa_type: ClassVar[Any] = pa.struct({
|
||||
"path": pa.string(),
|
||||
"timestamp": pa.float32()
|
||||
})
|
||||
_type: str = field(default="VideoFrame", init=False, repr=False)
|
||||
|
||||
def __call__(self):
|
||||
return self.pa_type
|
||||
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"'register_feature' is experimental and might be subject to breaking changes in the future.",
|
||||
category=UserWarning,
|
||||
)
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def populate_queues(
|
||||
queues: Dict[str, Deque[Any]],
|
||||
batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
|
||||
|
||||
for key in batch:
|
||||
if key not in queues:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
queues[key].append(batch[key])
|
||||
else:
|
||||
queues[key].append(batch[key])
|
||||
return queues
|
||||
|
||||
|
||||
def log_to_tensorboard(
|
||||
writer: SummaryWriter,
|
||||
data: Union[torch.Tensor, Any],
|
||||
tag: str,
|
||||
fps: int = 10) -> None:
|
||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||
video = data
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = grid.unsqueeze(dim=0)
|
||||
writer.add_video(tag, grid, fps=fps)
|
||||
463
scripts/evaluation/real_eval_server.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import argparse, os, sys
|
||||
import torch
|
||||
import torchvision
|
||||
import warnings
|
||||
import imageio
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend('agg')
|
||||
import traceback
|
||||
import uvicorn
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import nn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Any, Dict, Optional, Tuple, List
|
||||
from datetime import datetime
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
"""Get a module's device by checking one of its parameters.
|
||||
|
||||
Args:
|
||||
module (nn.Module): PyTorch module.
|
||||
|
||||
Returns:
|
||||
torch.device: The device where the module's parameters are stored.
|
||||
"""
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
"""Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to load weights into.
|
||||
ckpt (str): Path to checkpoint file.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with loaded weights.
|
||||
"""
|
||||
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
if "state_dict" in list(state_dict.keys()):
|
||||
state_dict = state_dict["state_dict"]
|
||||
try:
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
except:
|
||||
new_pl_sd = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_pl_sd[k] = v
|
||||
|
||||
for k in list(new_pl_sd.keys()):
|
||||
if "framestride_embed" in k:
|
||||
new_key = k.replace("framestride_embed", "fps_embedding")
|
||||
new_pl_sd[new_key] = new_pl_sd[k]
|
||||
del new_pl_sd[k]
|
||||
model.load_state_dict(new_pl_sd, strict=False)
|
||||
else:
|
||||
new_pl_sd = OrderedDict()
|
||||
for key in state_dict['module'].keys():
|
||||
new_pl_sd[key[16:]] = state_dict['module'][key]
|
||||
model.load_state_dict(new_pl_sd)
|
||||
print('>>> model checkpoint loaded.')
|
||||
return model
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None:
|
||||
"""Write a video to disk using imageio.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to save the video.
|
||||
stacked_frames (List[Any]): Frames to write.
|
||||
fps (int): Frames per second.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore",
|
||||
"pkg_resources is deprecated as an API",
|
||||
category=DeprecationWarning)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None:
|
||||
"""Save a video tensor as an MP4 file.
|
||||
|
||||
Args:
|
||||
video (torch.Tensor): Video tensor of shape (B, C, T, H, W).
|
||||
filename (str): Path to save video.
|
||||
fps (int, optional): Frame rate. Defaults to 8.
|
||||
|
||||
"""
|
||||
video = video.detach().cpu()
|
||||
video = torch.clamp(video.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
torchvision.io.write_video(filename,
|
||||
grid,
|
||||
fps=fps,
|
||||
video_codec='h264',
|
||||
options={'crf': '10'})
|
||||
|
||||
|
||||
def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode videos into latent space.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model with `encode_first_stage` method.
|
||||
videos (torch.Tensor): Input videos (B, C, T, H, W).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Latent representation (B, C, T, H, W).
|
||||
|
||||
"""
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
|
||||
|
||||
def image_guided_synthesis(
|
||||
model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
observation: Dict[str, torch.Tensor],
|
||||
noise_shape: tuple[int, int, int, int, int],
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 1.0,
|
||||
unconditional_guidance_scale: float = 1.0,
|
||||
fs: int | None = None,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
**kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Run inference with DDIM sampling.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Diffusion model.
|
||||
prompts (Any): Conditioning text prompts.
|
||||
observation (Dict[str, torch.Tensor]): Observation dictionary.
|
||||
noise_shape (List[int]): Shape of noise tensor.
|
||||
ddim_steps (int, optional): Number of DDIM steps. Defaults to 50.
|
||||
ddim_eta (float, optional): Sampling eta. Defaults to 1.0.
|
||||
unconditional_guidance_scale (float, optional): Guidance scale. Defaults to 1.0.
|
||||
fs (Optional[int], optional): Frame stride or FPS. Defaults to None.
|
||||
timestep_spacing (str, optional): Spacing strategy. Defaults to "uniform".
|
||||
guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0.
|
||||
**kwargs (Any): Additional arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
|
||||
b, _, t, _, _ = noise_shape
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
batch_size = noise_shape[0]
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
img = observation['observation.images.top']
|
||||
cond_img = img[:, -1, ...]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=noise_shape[2])
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
cond_state = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = model.agent_state_pos_emb + cond_state
|
||||
|
||||
cond_action = model.action_projector(observation['action'])
|
||||
cond_action_emb = model.agent_action_pos_emb + cond_action
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
|
||||
cond["c_crossattn"] = [
|
||||
torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
|
||||
]
|
||||
cond["c_crossattn_action"] = [
|
||||
observation['observation.images.top'].permute(
|
||||
0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:],
|
||||
observation['observation.state'][:, -model.n_obs_steps_acting:]
|
||||
]
|
||||
|
||||
uc = None
|
||||
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
|
||||
if ddim_sampler is not None:
|
||||
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int,
|
||||
gpu_no: int) -> Tuple[nn.Module, List[int], Any]:
|
||||
"""
|
||||
Run inference pipeline on prompts and image inputs.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Parsed command-line arguments.
|
||||
gpu_num (int): Number of GPUs.
|
||||
gpu_no (int): Index of the current GPU.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Load config
|
||||
config = OmegaConf.load(args.config)
|
||||
# Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model = model.cuda(gpu_no)
|
||||
model.eval()
|
||||
print(">>> Model is successfully loaded ...")
|
||||
|
||||
# Build unnomalizer
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
|
||||
## Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
== 0), "Error: image size [h,w] should be multiples of 16!"
|
||||
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
|
||||
|
||||
## Get latent noise shape
|
||||
h, w = args.height // 8, args.width // 8
|
||||
channels = model.model.diffusion_model.out_channels
|
||||
n_frames = args.video_length
|
||||
print(f'>>> Generate {n_frames} frames under each generation ...')
|
||||
noise_shape = [args.bs, channels, n_frames, h, w]
|
||||
|
||||
return model, noise_shape, data
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--savedir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the results.")
|
||||
parser.add_argument("--ckpt_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the model checkpoint.")
|
||||
parser.add_argument("--config", type=str, help="Path to the config file.")
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of DDIM steps. If non-positive, DDPM is used instead.")
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
|
||||
parser.add_argument("--bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size for inference. Must be 1.")
|
||||
parser.add_argument("--height",
|
||||
type=int,
|
||||
default=320,
|
||||
help="Height of the generated images in pixels.")
|
||||
parser.add_argument("--width",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Width of the generated images in pixels.")
|
||||
parser.add_argument(
|
||||
"--frame_stride",
|
||||
type=int,
|
||||
default=3,
|
||||
help=
|
||||
"frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unconditional_guidance_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale for classifier-free guidance during sampling.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=123,
|
||||
help="Random seed for reproducibility.")
|
||||
parser.add_argument("--video_length",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of frames in the generated video.")
|
||||
parser.add_argument(
|
||||
"--timestep_spacing",
|
||||
type=str,
|
||||
default="uniform",
|
||||
help=
|
||||
"Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_rescale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help=
|
||||
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--perframe_ae",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class Server:
|
||||
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0)
|
||||
self.args_ = args
|
||||
self.dataset_name = self.data_.dataset_configs['test']['params'][
|
||||
'dataset_name']
|
||||
self.device_ = get_device_from_parameters(self.model_)
|
||||
|
||||
def normalize_image(self, image: torch.Tensor) -> torch.Tensor:
|
||||
return (image / 255 - 0.5) * 2
|
||||
|
||||
def predict_action(self, payload: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
images = payload['observation.images.top']
|
||||
states = payload['observation.state']
|
||||
actions = payload['action'] # Should be all zeros
|
||||
language_instruction = payload['language_instruction']
|
||||
|
||||
images = torch.tensor(images).cuda()
|
||||
images = self.data_.test_datasets[
|
||||
self.dataset_name].spatial_transform(images).unsqueeze(0)
|
||||
images = self.normalize_image(images)
|
||||
print(f"images shape: {images.shape} ...")
|
||||
states = torch.tensor(states)
|
||||
states = self.data_.test_datasets[self.dataset_name].normalizer(
|
||||
{'observation.state': states})['observation.state']
|
||||
states, _ = self.data_.test_datasets[
|
||||
self.dataset_name]._map_to_uni_state(states, "joint position")
|
||||
print(f"states shape: {states.shape} ...")
|
||||
actions = torch.tensor(actions)
|
||||
actions, action_mask = self.data_.test_datasets[
|
||||
self.dataset_name]._map_to_uni_action(actions,
|
||||
"joint position")
|
||||
print(f"actions shape: {actions.shape} ...")
|
||||
print("=" * 20)
|
||||
states = states.unsqueeze(0).cuda()
|
||||
actions = actions.unsqueeze(0).cuda()
|
||||
|
||||
observation = {
|
||||
'observation.images.top': images,
|
||||
'observation.state': states,
|
||||
'action': actions
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(self.device_, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
|
||||
args = self.args_
|
||||
pred_videos, pred_action, _ = image_guided_synthesis(
|
||||
self.model_,
|
||||
language_instruction,
|
||||
observation,
|
||||
self.noise_shape_,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_ets=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||
fs=30 / args.frame_stride,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
|
||||
pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu()
|
||||
pred_action = self.data_.test_datasets[
|
||||
self.dataset_name].unnormalizer({'action':
|
||||
pred_action})['action']
|
||||
|
||||
os.makedirs(args.savedir, exist_ok=True)
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
video_file = f'{args.savedir}/{current_time}.mp4'
|
||||
save_results(pred_videos.cpu(), video_file)
|
||||
|
||||
response = {
|
||||
'result': 'ok',
|
||||
'action': pred_action.tolist(),
|
||||
'desc': 'success'
|
||||
}
|
||||
return JSONResponse(response)
|
||||
|
||||
except:
|
||||
logging.error(traceback.format_exc())
|
||||
logging.warning(
|
||||
"Your request threw an error; make sure your request complies with the expected format:\n"
|
||||
"{'image': np.ndarray, 'instruction': str}\n"
|
||||
"You can optionally an `unnorm_key: str` to specific the dataset statistics you want to use for "
|
||||
"de-normalizing the output actions.")
|
||||
return {'result': 'error', 'desc': traceback.format_exc()}
|
||||
|
||||
def run(self, host: str = "127.0.0.1", port: int = 8000) -> None:
|
||||
self.app = FastAPI()
|
||||
self.app.post("/predict_action")(self.predict_action)
|
||||
print(">>> Inference server is ready ... ")
|
||||
uvicorn.run(self.app, host=host, port=port)
|
||||
print(">>> Inference server stops ... ")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
seed_everything(seed)
|
||||
rank, gpu_num = 0, 1
|
||||
print(">>> Launch inference server ... ")
|
||||
server = Server(args)
|
||||
server.run()
|
||||
1220
scripts/evaluation/world_model_interaction.py
Normal file
23
scripts/run_base_model_inference.sh
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
model_name=base_model
|
||||
ckpt=/path/to/base/model
|
||||
config=configs/inference/base_model_inference.yaml
|
||||
res_dir="/path/to/result/directory"
|
||||
seed=123
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
|
||||
--seed ${seed} \
|
||||
--ckpt_path $ckpt \
|
||||
--config $config \
|
||||
--savedir "${res_dir}/videos" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
--unconditional_guidance_scale 1.0 \
|
||||
--ddim_steps 16 \
|
||||
--ddim_eta 1.0 \
|
||||
--prompt_dir "/path/to/examples/base_model_prompts" \
|
||||
--text_input \
|
||||
--video_length 16 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
26
scripts/run_real_eval_server.sh
Normal file
@@ -0,0 +1,26 @@
|
||||
model_name=testing
|
||||
ckpt=/path/to/model/checkpoint
|
||||
config=configs/inference/world_model_decision_making.yaml
|
||||
seed=123
|
||||
res_dir="path/to/results/directory"
|
||||
datasets=(
|
||||
"unitree_g1_pack_camera"
|
||||
)
|
||||
|
||||
|
||||
for dataset in "${datasets[@]}"; do
|
||||
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \
|
||||
--seed ${seed} \
|
||||
--ckpt_path $ckpt \
|
||||
--config $config \
|
||||
--savedir "${res_dir}/${dataset}/${model_name}/videos" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
--unconditional_guidance_scale 1.0 \
|
||||
--ddim_steps 16 \
|
||||
--ddim_eta 1.0 \
|
||||
--video_length 16 \
|
||||
--frame_stride 2 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
done
|
||||
42
scripts/run_world_model_interaction.sh
Normal file
@@ -0,0 +1,42 @@
|
||||
model_name=testing
|
||||
ckpt=/path/to/model/checkpoint
|
||||
config=configs/inference/world_model_interaction.yaml
|
||||
seed=123
|
||||
res_dir="/path/to/result/directory"
|
||||
|
||||
datasets=(
|
||||
"unitree_z1_stackbox"
|
||||
"unitree_z1_dual_arm_stackbox"
|
||||
"unitree_z1_dual_arm_stackbox_v2"
|
||||
"unitree_z1_dual_arm_cleanup_pencils"
|
||||
"unitree_g1_pack_camera"
|
||||
)
|
||||
|
||||
n_iters=(12 7 11 8 11)
|
||||
fses=(4 4 4 4 6)
|
||||
|
||||
for i in "${!datasets[@]}"; do
|
||||
dataset=${datasets[$i]}
|
||||
n_iter=${n_iters[$i]}
|
||||
fs=${fses[$i]}
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed ${seed} \
|
||||
--ckpt_path $ckpt \
|
||||
--config $config \
|
||||
--savedir "${res_dir}/${model_name}/${dataset}" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
--unconditional_guidance_scale 1.0 \
|
||||
--ddim_steps 50 \
|
||||
--ddim_eta 1.0 \
|
||||
--prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
|
||||
--dataset ${dataset} \
|
||||
--video_length 16 \
|
||||
--frame_stride ${fs} \
|
||||
--n_action_steps 16 \
|
||||
--exe_steps 16 \
|
||||
--n_iter ${n_iter} \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
done
|
||||
32
scripts/train.sh
Normal file
@@ -0,0 +1,32 @@
|
||||
# NCCL configuration
|
||||
# export NCCL_DEBUG=debug
|
||||
# export NCCL_IB_DISABLE=0
|
||||
# export NCCL_IB_GID_INDEX=3
|
||||
# export NCCL_NET_GDR_LEVEL=3
|
||||
# export CUDA_LAUNCH_BLOCKING=1
|
||||
|
||||
# export NCCL_TOPO_FILE=/tmp/topo.txt
|
||||
# export MASTER_ADDR="master.ip."
|
||||
# export MASTER_PROT=12366
|
||||
|
||||
|
||||
# args
|
||||
name="experiment_name"
|
||||
config_file=configs/train/config.yaml
|
||||
|
||||
# save root dir for logs, checkpoints, tensorboard record, etc.
|
||||
save_root="/path/to/savedir"
|
||||
|
||||
mkdir -p $save_root/$name
|
||||
|
||||
## run
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
|
||||
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
|
||||
./scripts/trainer.py \
|
||||
--base $config_file \
|
||||
--train \
|
||||
--name $name \
|
||||
--logdir $save_root \
|
||||
--devices 8 \
|
||||
--total_gpus=8 \
|
||||
lightning.trainer.num_nodes=1
|
||||
214
scripts/trainer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import argparse, os, datetime
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import logging as transf_logging
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
parser.add_argument("--seed",
|
||||
"-s",
|
||||
type=int,
|
||||
default=20250912,
|
||||
help="seed for seed_everything")
|
||||
parser.add_argument("--name",
|
||||
"-n",
|
||||
type=str,
|
||||
default="",
|
||||
help="experiment name, as saving folder")
|
||||
parser.add_argument(
|
||||
"--base",
|
||||
"-b",
|
||||
nargs="*",
|
||||
metavar="base_config.yaml",
|
||||
help="paths to base configs. Loaded from left-to-right.",
|
||||
default=list())
|
||||
parser.add_argument("--train",
|
||||
"-t",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='train')
|
||||
parser.add_argument("--val",
|
||||
"-v",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='val')
|
||||
parser.add_argument("--test",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='test')
|
||||
parser.add_argument("--logdir",
|
||||
"-l",
|
||||
type=str,
|
||||
default="logs",
|
||||
help="directory for logging dat shit")
|
||||
parser.add_argument("--auto_resume",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="resume from full-info checkpoint")
|
||||
parser.add_argument("--auto_resume_weight_only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="resume from weight-only checkpoint")
|
||||
parser.add_argument("--debug",
|
||||
"-d",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="enable post-mortem debugging")
|
||||
return parser
|
||||
|
||||
|
||||
def get_nondefault_trainer_args(args):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
default_trainer_args = parser.parse_args([])
|
||||
return sorted(k for k in vars(default_trainer_args)
|
||||
if getattr(args, k) != getattr(default_trainer_args, k))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
local_rank = int(os.environ.get('LOCAL_RANK'))
|
||||
global_rank = int(os.environ.get('RANK'))
|
||||
num_rank = int(os.environ.get('WORLD_SIZE'))
|
||||
|
||||
parser = get_parser()
|
||||
# Extends existing argparse by default Trainer attributes
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args, unknown = parser.parse_known_args()
|
||||
transf_logging.set_verbosity_error()
|
||||
seed_everything(args.seed)
|
||||
|
||||
configs = [OmegaConf.load(cfg) for cfg in args.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
lightning_config = config.pop("lightning", OmegaConf.create())
|
||||
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
||||
|
||||
# Setup workspace directories
|
||||
workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
|
||||
config,
|
||||
lightning_config,
|
||||
global_rank)
|
||||
logger = set_logger(
|
||||
logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
|
||||
logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
|
||||
logger.info("***** Configing Model *****")
|
||||
config.model.params.logdir = workdir
|
||||
model = instantiate_from_config(config.model)
|
||||
# Load checkpoints
|
||||
model = load_checkpoints(model, config.model)
|
||||
|
||||
# Register_schedule again to make ZTSNR work
|
||||
if model.rescale_betas_zero_snr:
|
||||
model.register_schedule(given_betas=model.given_betas,
|
||||
beta_schedule=model.beta_schedule,
|
||||
timesteps=model.timesteps,
|
||||
linear_start=model.linear_start,
|
||||
linear_end=model.linear_end,
|
||||
cosine_s=model.cosine_s)
|
||||
|
||||
# Update trainer config
|
||||
for k in get_nondefault_trainer_args(args):
|
||||
trainer_config[k] = getattr(args, k)
|
||||
|
||||
num_nodes = trainer_config.num_nodes
|
||||
ngpu_per_node = trainer_config.devices
|
||||
logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
|
||||
|
||||
# Setup learning rate
|
||||
base_lr = config.model.base_learning_rate
|
||||
bs = config.data.params.batch_size
|
||||
if getattr(config.model, 'scale_lr', True):
|
||||
model.learning_rate = num_rank * bs * base_lr
|
||||
else:
|
||||
model.learning_rate = base_lr
|
||||
|
||||
logger.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
for k in data.train_datasets:
|
||||
logger.info(
|
||||
f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
|
||||
)
|
||||
if hasattr(data, 'val_datasets'):
|
||||
for k in data.val_datasets:
|
||||
logger.info(
|
||||
f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
|
||||
)
|
||||
|
||||
for item in unknown:
|
||||
if item.startswith('--total_gpus'):
|
||||
num_gpus = int(item.split('=')[-1])
|
||||
break
|
||||
model.datasets_len = len(data)
|
||||
|
||||
logger.info("***** Configing Trainer *****")
|
||||
if "accelerator" not in trainer_config:
|
||||
trainer_config["accelerator"] = "gpu"
|
||||
|
||||
# Setup trainer args: pl-logger and callbacks
|
||||
trainer_kwargs = dict()
|
||||
trainer_kwargs["num_sanity_val_steps"] = 0
|
||||
logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
|
||||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
||||
|
||||
# Setup callbacks
|
||||
callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
|
||||
ckptdir, logger)
|
||||
trainer_kwargs["callbacks"] = [
|
||||
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
|
||||
]
|
||||
strategy_cfg = get_trainer_strategy(lightning_config)
|
||||
trainer_kwargs["strategy"] = strategy_cfg if type(
|
||||
strategy_cfg) == str else instantiate_from_config(strategy_cfg)
|
||||
trainer_kwargs['precision'] = lightning_config.get('precision', 32)
|
||||
trainer_kwargs["sync_batchnorm"] = False
|
||||
|
||||
# Trainer config: others
|
||||
trainer_args = argparse.Namespace(**trainer_config)
|
||||
trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
|
||||
|
||||
# Allow checkpointing via USR1
|
||||
def melk(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
print("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
def divein(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
import pudb
|
||||
pudb.set_trace()
|
||||
|
||||
import signal
|
||||
signal.signal(signal.SIGUSR1, melk)
|
||||
signal.signal(signal.SIGUSR2, divein)
|
||||
|
||||
# List the key model sizes
|
||||
total_params = get_num_parameters(model)
|
||||
|
||||
logger.info("***** Running the Loop *****")
|
||||
if args.train:
|
||||
try:
|
||||
if "strategy" in lightning_config and lightning_config[
|
||||
'strategy'].startswith('deepspeed'):
|
||||
logger.info("<Training in DeepSpeed Mode>")
|
||||
if trainer_kwargs['precision'] == 16:
|
||||
with torch.cuda.amp.autocast():
|
||||
trainer.fit(model, data)
|
||||
else:
|
||||
trainer.fit(model, data)
|
||||
else:
|
||||
logger.info("<Training in DDPSharded Mode>")
|
||||
trainer.fit(model, data)
|
||||
except Exception:
|
||||
raise
|
||||
0
src/unifolm_wma/__init__.py
Normal file
26
src/unifolm_wma/data/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
|
||||
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||
super().__init__()
|
||||
self.num_records = num_records
|
||||
self.valid_ids = valid_ids
|
||||
self.sample_ids = valid_ids
|
||||
self.size = size
|
||||
|
||||
print(
|
||||
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_records
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
230
src/unifolm_wma/data/normolize.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
shapes: Dict[str, List[int]],
|
||||
modes: Dict[str, str],
|
||||
stats: Dict[str, Dict[str, Tensor]] = None,
|
||||
) -> Dict[str, Dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
Dict: A Dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
if "image" in key:
|
||||
# sanity checks
|
||||
assert len(
|
||||
shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_Dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
if "action" in key:
|
||||
target_key = "action"
|
||||
elif "state" in key:
|
||||
target_key = 'observation.state'
|
||||
else:
|
||||
target_key = key
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict({
|
||||
"mean":
|
||||
nn.Parameter(mean, requires_grad=False),
|
||||
"std":
|
||||
nn.Parameter(std, requires_grad=False),
|
||||
})
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict({
|
||||
"min":
|
||||
nn.Parameter(min, requires_grad=False),
|
||||
"max":
|
||||
nn.Parameter(max, requires_grad=False),
|
||||
})
|
||||
|
||||
if stats is not None:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[target_key]["mean"].clone()
|
||||
buffer["std"].data = stats[target_key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[target_key]["min"].clone()
|
||||
buffer["max"].data = stats[target_key]["max"].clone()
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model.")
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: Dict[str, List[int]],
|
||||
modes: Dict[str, str],
|
||||
stats: Dict[str, Dict[str, Tensor]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are Dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_Dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
for key, mode in self.modes.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: Dict[str, List[int]],
|
||||
modes: Dict[str, str],
|
||||
stats: Dict[str, Dict[str, Tensor]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are Dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_Dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
for key, mode in self.modes.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
60
src/unifolm_wma/data/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from typing import Dict, List, Union
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
d[parts[-1]] = value
|
||||
return outdict
|
||||
|
||||
|
||||
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(
|
||||
root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id,
|
||||
"meta_data/episode_data_index.safetensors",
|
||||
repo_type="dataset",
|
||||
revision=version)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id,
|
||||
"meta_data/stats.safetensors",
|
||||
repo_type="dataset",
|
||||
revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
408
src/unifolm_wma/data/wma_data.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import torch
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
import h5py
|
||||
|
||||
from decord import VideoReader, cpu
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from pathlib import Path
|
||||
|
||||
from unifolm_wma.data.utils import load_stats
|
||||
from unifolm_wma.data.normolize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class WMAData(Dataset):
|
||||
"""
|
||||
Assuming the following dataset structure:
|
||||
dataset_dir/
|
||||
├── videos
|
||||
│ ├──dataset_name
|
||||
│ │ ├──camera_view_dir
|
||||
│ │ ├── 0.mp4
|
||||
│ │ ├── 1.mp4
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── transitions
|
||||
│ ├── dataset_name
|
||||
│ ├── meta_data
|
||||
│ ├── 0.h5
|
||||
│ ├── 1.h5
|
||||
│ └── ...
|
||||
└── dataset_name.csv
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meta_path,
|
||||
data_dir,
|
||||
subsample=None,
|
||||
video_length=16,
|
||||
resolution=[256, 512],
|
||||
frame_stride=1,
|
||||
frame_stride_min=1,
|
||||
spatial_transform=None,
|
||||
crop_resolution=None,
|
||||
fps_max=None,
|
||||
load_raw_resolution=False,
|
||||
fixed_fps=None,
|
||||
random_fs=False,
|
||||
cond_robot_label_prob=0.0,
|
||||
transition_dir=None,
|
||||
dataset_name=None,
|
||||
normalization_mode='min_max',
|
||||
individual_normalization=False,
|
||||
n_obs_steps=1,
|
||||
max_action_dim=7,
|
||||
max_state_dim=7,
|
||||
):
|
||||
self.meta_path = meta_path
|
||||
self.data_dir = data_dir
|
||||
self.subsample = subsample
|
||||
self.video_length = video_length
|
||||
self.resolution = [resolution, resolution] if isinstance(
|
||||
resolution, int) else resolution
|
||||
self.fps_max = fps_max
|
||||
self.frame_stride = frame_stride
|
||||
self.frame_stride_min = frame_stride_min
|
||||
self.fixed_fps = fixed_fps
|
||||
self.load_raw_resolution = load_raw_resolution
|
||||
self.random_fs = random_fs
|
||||
self.cond_robot_label_prob = cond_robot_label_prob
|
||||
self.transition_dir = transition_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_action_dim = max_action_dim
|
||||
self.max_state_dim = max_state_dim
|
||||
|
||||
self._load_metadata()
|
||||
if spatial_transform is not None:
|
||||
if spatial_transform == "random_crop":
|
||||
self.spatial_transform = transforms.RandomCrop(crop_resolution)
|
||||
elif spatial_transform == "center_crop":
|
||||
self.spatial_transform = transforms.Compose([
|
||||
transforms.CenterCrop(resolution),
|
||||
])
|
||||
elif spatial_transform == "resize_center_crop":
|
||||
self.spatial_transform = transforms.Compose([
|
||||
transforms.Resize(min(self.resolution)),
|
||||
transforms.CenterCrop(self.resolution),
|
||||
])
|
||||
elif spatial_transform == "resize":
|
||||
self.spatial_transform = transforms.Resize(self.resolution)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.spatial_transform = None
|
||||
|
||||
self.normalization_mode = normalization_mode
|
||||
self.individual_normalization = individual_normalization
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self._load_stats()
|
||||
if individual_normalization:
|
||||
self._init_normalizers()
|
||||
|
||||
def _load_metadata(self):
|
||||
metadata = pd.read_csv(self.meta_path, dtype=str)
|
||||
if self.subsample is not None:
|
||||
metadata = metadata.sample(self.subsample, random_state=0)
|
||||
|
||||
self.metadata = metadata
|
||||
# drop the rows contain NaN values
|
||||
self.metadata.dropna(inplace=True)
|
||||
print(
|
||||
f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
|
||||
)
|
||||
|
||||
def _load_stats(self):
|
||||
self.stats = load_stats(self.dataset_name, None, self.transition_dir)
|
||||
print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")
|
||||
|
||||
def _init_normalizers(self):
|
||||
shape_dict = {
|
||||
'pre_action': [self.stats['action']['max'].shape[-1]],
|
||||
'action': [self.stats['action']['max'].shape[-1]],
|
||||
'observation.state':
|
||||
[self.stats['observation.state']['max'].shape[-1]],
|
||||
'next.state': [self.stats['observation.state']['max'].shape[-1]]
|
||||
}
|
||||
normalization_mode_dict = {
|
||||
'pre_action': self.normalization_mode,
|
||||
'action': self.normalization_mode,
|
||||
'observation.state': self.normalization_mode,
|
||||
'next.state': self.normalization_mode
|
||||
}
|
||||
self.normalizer = Normalize(shape_dict, normalization_mode_dict,
|
||||
self.stats)
|
||||
self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
|
||||
self.stats)
|
||||
print(
|
||||
f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")
|
||||
|
||||
def _get_video_path(self, sample):
|
||||
rel_video_fp = os.path.join(sample['data_dir'],
|
||||
str(sample['videoid']) + '.mp4')
|
||||
full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
|
||||
return full_video_fp
|
||||
|
||||
def _get_transition_path(self, sample):
|
||||
data_dir = Path(sample['data_dir'])
|
||||
if self.dataset_name == data_dir.name:
|
||||
rel_transition_fp = os.path.join(str(data_dir),
|
||||
str(sample['videoid']) + '.h5')
|
||||
else:
|
||||
rel_transition_fp = os.path.join(str(data_dir.parent),
|
||||
str(sample['videoid']) + '.h5')
|
||||
full_transition_fp = os.path.join(self.data_dir, 'transitions',
|
||||
rel_transition_fp)
|
||||
return full_transition_fp
|
||||
|
||||
def get_uni_vec(self, action_state_dict, action_type, state_type):
|
||||
if 'pre_action' in action_state_dict:
|
||||
action_state_dict['pre_action'], _ = self._map_to_uni_action(
|
||||
action_state_dict['pre_action'], action_type)
|
||||
if 'action' in action_state_dict:
|
||||
action_state_dict['action'], action_state_dict[
|
||||
'action_mask'] = self._map_to_uni_action(
|
||||
action_state_dict['action'], action_type)
|
||||
if 'observation.state' in action_state_dict:
|
||||
action_state_dict['observation.state'], _ = self._map_to_uni_state(
|
||||
action_state_dict['observation.state'], state_type)
|
||||
if 'next.state' in action_state_dict:
|
||||
action_state_dict['next.state'], action_state_dict[
|
||||
'state_mask'] = self._map_to_uni_state(
|
||||
action_state_dict['next.state'], state_type)
|
||||
return action_state_dict
|
||||
|
||||
def _map_to_uni_action(self, action, action_type):
|
||||
action_dim = action.shape[-1]
|
||||
uni_action = torch.nn.functional.pad(
|
||||
action, (0, self.max_action_dim - action_dim),
|
||||
mode='constant',
|
||||
value=0)
|
||||
uni_action_mask = torch.zeros_like(uni_action)
|
||||
uni_action_mask[:, :action_dim] = 1
|
||||
return uni_action, uni_action_mask
|
||||
|
||||
def _map_to_uni_state(self, state, state_type):
|
||||
state_dim = state.shape[-1]
|
||||
uni_state = torch.nn.functional.pad(
|
||||
state, (0, self.max_state_dim - state_dim),
|
||||
mode='constant',
|
||||
value=0)
|
||||
uni_state_mask = torch.zeros_like(uni_state)
|
||||
uni_state_mask[:, :state_dim] = 1
|
||||
return uni_state, uni_state_mask
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
if self.random_fs:
|
||||
frame_stride = random.randint(self.frame_stride_min,
|
||||
self.frame_stride)
|
||||
else:
|
||||
frame_stride = self.frame_stride
|
||||
|
||||
# Get frames until success
|
||||
while True:
|
||||
index = index % len(self.metadata)
|
||||
sample = self.metadata.iloc[index]
|
||||
video_path = self._get_video_path(sample)
|
||||
|
||||
instruction = sample['instruction']
|
||||
if self.cond_robot_label_prob > 0.0 and random.random(
|
||||
) < self.cond_robot_label_prob:
|
||||
if sample['embodiment'] != 'x':
|
||||
instruction = sample['embodiment'] + ' [SEP] ' + sample[
|
||||
'instruction']
|
||||
try:
|
||||
if self.load_raw_resolution:
|
||||
video_reader = VideoReader(video_path, ctx=cpu(0))
|
||||
else:
|
||||
video_reader = VideoReader(video_path,
|
||||
ctx=cpu(0),
|
||||
width=530,
|
||||
height=300)
|
||||
if len(video_reader) < self.video_length:
|
||||
print(
|
||||
f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})"
|
||||
)
|
||||
index += 1
|
||||
continue
|
||||
else:
|
||||
pass
|
||||
except:
|
||||
index += 1
|
||||
print(f">>> Error: load video failed! path = {video_path}")
|
||||
continue
|
||||
|
||||
fps_ori = video_reader.get_avg_fps()
|
||||
if self.fixed_fps is not None:
|
||||
frame_stride = int(frame_stride *
|
||||
(1.0 * fps_ori / self.fixed_fps))
|
||||
|
||||
# To avoid extreme cases when fixed_fps is used
|
||||
frame_stride = max(frame_stride, 1)
|
||||
|
||||
# Get valid range (adapting case by case)
|
||||
required_frame_num = frame_stride * (self.video_length - 1) + 1
|
||||
frame_num = len(video_reader)
|
||||
if frame_num < required_frame_num:
|
||||
# Drop extra samples if fixed fps is required
|
||||
if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
|
||||
index += 1
|
||||
continue
|
||||
else:
|
||||
frame_stride = frame_num // self.video_length
|
||||
required_frame_num = frame_stride * (self.video_length -
|
||||
1) + 1
|
||||
|
||||
# Select a random clip
|
||||
random_range = frame_num - required_frame_num
|
||||
start_idx = random.randint(
|
||||
0, random_range -
|
||||
frame_stride) if random_range - frame_stride > 0 else 0
|
||||
|
||||
# Calculate frame indices
|
||||
frame_indices = [
|
||||
start_idx + frame_stride * i for i in range(self.video_length)
|
||||
]
|
||||
try:
|
||||
next_frame_indices = [
|
||||
idx + frame_stride for idx in frame_indices
|
||||
]
|
||||
frames = video_reader.get_batch(next_frame_indices)
|
||||
break
|
||||
except:
|
||||
print(
|
||||
f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
|
||||
)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
# Load transition data
|
||||
transition_path = self._get_transition_path(sample)
|
||||
with h5py.File(transition_path, 'r') as h5f:
|
||||
transition_dict = {}
|
||||
for key in h5f.keys():
|
||||
transition_dict[key] = torch.tensor(h5f[key][()])
|
||||
for key in h5f.attrs.keys():
|
||||
transition_dict[key] = h5f.attrs[key]
|
||||
|
||||
# Load observable states
|
||||
if start_idx < self.n_obs_steps - 1:
|
||||
state_indices = list(range(0, start_idx + 1))
|
||||
states = transition_dict['observation.state'][state_indices, :]
|
||||
num_padding = self.n_obs_steps - 1 - start_idx
|
||||
first_slice = states[0:1, :] # (t, d)
|
||||
padding = first_slice.repeat(num_padding, 1)
|
||||
states = torch.cat((padding, states), dim=0)
|
||||
else:
|
||||
state_indices = list(
|
||||
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
|
||||
states = transition_dict['observation.state'][state_indices, :]
|
||||
assert states.shape[
|
||||
0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'
|
||||
|
||||
# Load observable actions
|
||||
if start_idx < self.n_obs_steps:
|
||||
pre_action_indices = list(range(0, start_idx))
|
||||
pre_actions = transition_dict['action'][pre_action_indices, :]
|
||||
num_padding = self.n_obs_steps - start_idx
|
||||
first_slice = torch.zeros_like(transition_dict['action'][:1, :])
|
||||
padding = first_slice.repeat(num_padding, 1)
|
||||
pre_actions = torch.cat((padding, pre_actions), dim=0)
|
||||
else:
|
||||
pre_action_indices = list(
|
||||
range(start_idx - self.n_obs_steps, start_idx))
|
||||
pre_actions = transition_dict['action'][pre_action_indices, :]
|
||||
assert pre_actions.shape[
|
||||
0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"
|
||||
|
||||
# Load future actions
|
||||
actions = transition_dict['action'][frame_indices, :]
|
||||
# Load future states
|
||||
next_state_indices = [idx + frame_stride for idx in frame_indices]
|
||||
next_states = transition_dict['observation.state'][
|
||||
next_state_indices, :]
|
||||
frames_action_state_dict = {
|
||||
'pre_action': pre_actions,
|
||||
'action': actions,
|
||||
'observation.state': states,
|
||||
'next.state': next_states
|
||||
}
|
||||
if self.individual_normalization:
|
||||
frames_action_state_dict = self.normalizer(
|
||||
frames_action_state_dict)
|
||||
|
||||
# Update action and states to unified vector
|
||||
frames_action_state_dict = self.get_uni_vec(
|
||||
frames_action_state_dict,
|
||||
transition_dict['action_type'],
|
||||
transition_dict['state_type'],
|
||||
)
|
||||
|
||||
# Load observable images
|
||||
if start_idx < self.n_obs_steps - 1:
|
||||
action_net_frame_indices = list(range(0, start_idx + 1))
|
||||
action_net_frames = video_reader.get_batch(
|
||||
action_net_frame_indices)
|
||||
action_net_frames = torch.tensor(
|
||||
action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
|
||||
first_slice = action_net_frames[0:1, :]
|
||||
num_padding = self.n_obs_steps - 1 - start_idx
|
||||
padding = first_slice.repeat(num_padding, 1, 1, 1)
|
||||
action_net_frames = torch.cat((padding, action_net_frames), dim=0)
|
||||
assert (
|
||||
action_net_frames.shape[0] == self.n_obs_steps
|
||||
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
|
||||
action_net_frames = action_net_frames.permute(1, 0, 2, 3)
|
||||
else:
|
||||
action_net_frame_indices = list(
|
||||
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
|
||||
action_net_frames = video_reader.get_batch(
|
||||
action_net_frame_indices)
|
||||
assert (
|
||||
action_net_frames.shape[0] == self.n_obs_steps
|
||||
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
|
||||
action_net_frames = torch.tensor(
|
||||
action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()
|
||||
|
||||
assert (frames.shape[0] == self.video_length
|
||||
), f'{len(frames)}, self.video_length={self.video_length}'
|
||||
frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
|
||||
|
||||
if self.spatial_transform is not None:
|
||||
frames = self.spatial_transform(frames)
|
||||
action_net_frames = self.spatial_transform(action_net_frames)
|
||||
|
||||
if self.resolution is not None:
|
||||
assert (frames.shape[2], frames.shape[3]) == (
|
||||
self.resolution[0], self.resolution[1]
|
||||
), f'frames={frames.shape}, self.resolution={self.resolution}'
|
||||
assert (
|
||||
action_net_frames.shape[2], action_net_frames.shape[3]
|
||||
) == (
|
||||
self.resolution[0], self.resolution[1]
|
||||
), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'
|
||||
|
||||
# Normalize frames tensors to [-1,1]
|
||||
frames = (frames / 255 - 0.5) * 2
|
||||
action_net_frames = (action_net_frames / 255 - 0.5) * 2
|
||||
fps_clip = fps_ori // frame_stride
|
||||
if self.fps_max is not None and fps_clip > self.fps_max:
|
||||
fps_clip = self.fps_max
|
||||
|
||||
data = {
|
||||
'video': frames,
|
||||
'instruction': instruction,
|
||||
'path': video_path,
|
||||
'fps': fps_clip,
|
||||
'frame_stride': frame_stride,
|
||||
'observation.image': action_net_frames,
|
||||
}
|
||||
data.update(frames_action_state_dict)
|
||||
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.metadata)
|
||||
0
src/unifolm_wma/models/__init__.py
Normal file
267
src/unifolm_wma/models/autoencoder.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from einops import rearrange
|
||||
from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
|
||||
from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
test=False,
|
||||
logdir=None,
|
||||
input_dim=4,
|
||||
test_args=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
self.input_dim = input_dim
|
||||
self.test = test
|
||||
self.test_args = test_args
|
||||
self.logdir = logdir
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer("colorize",
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
if self.test:
|
||||
self.init_test()
|
||||
|
||||
def init_test(self, ):
|
||||
self.test = True
|
||||
save_dir = os.path.join(self.logdir, "test")
|
||||
if 'ckpt' in self.test_args:
|
||||
ckpt_name = os.path.basename(self.test_args.ckpt).split(
|
||||
'.ckpt')[0] + f'_epoch{self._cur_epoch}'
|
||||
self.root = os.path.join(save_dir, ckpt_name)
|
||||
else:
|
||||
self.root = save_dir
|
||||
if 'test_subdir' in self.test_args:
|
||||
self.root = os.path.join(save_dir, self.test_args.test_subdir)
|
||||
|
||||
self.root_zs = os.path.join(self.root, "zs")
|
||||
self.root_dec = os.path.join(self.root, "reconstructions")
|
||||
self.root_inputs = os.path.join(self.root, "inputs")
|
||||
os.makedirs(self.root, exist_ok=True)
|
||||
|
||||
if self.test_args.save_z:
|
||||
os.makedirs(self.root_zs, exist_ok=True)
|
||||
if self.test_args.save_reconstruction:
|
||||
os.makedirs(self.root_dec, exist_ok=True)
|
||||
if self.test_args.save_input:
|
||||
os.makedirs(self.root_inputs, exist_ok=True)
|
||||
assert (self.test_args is not None)
|
||||
self.test_maximum = getattr(self.test_args, 'test_maximum', None)
|
||||
self.count = 0
|
||||
self.eval_metrics = {}
|
||||
self.decodes = []
|
||||
self.save_decode_samples = 2048
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
self._cur_epoch = sd['epoch']
|
||||
sd = sd["state_dict"]
|
||||
except:
|
||||
self._cur_epoch = 'null'
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z, **kwargs):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if x.dim() == 5 and self.input_dim == 4:
|
||||
b, c, t, h, w = x.shape
|
||||
self.b = b
|
||||
self.t = t
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train")
|
||||
self.log("aeloss",
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train")
|
||||
|
||||
self.log("discloss",
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val")
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
|
||||
list(self.decoder.parameters()) +
|
||||
list(self.quant_conv.parameters()) +
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize",
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
2524
src/unifolm_wma/models/ddpms.py
Normal file
0
src/unifolm_wma/models/diffusion_head/__init__.py
Normal file
217
src/unifolm_wma/models/diffusion_head/base_nets.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Contains torch Modules that correspond to basic network building blocks, like
|
||||
MLP, RNN, and CNN backbones.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
"""
|
||||
Base class for networks. The only difference from torch.nn.Module is that it
|
||||
requires implementing @output_shape.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def output_shape(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
"""
|
||||
================================================
|
||||
Visual Backbone Networks
|
||||
================================================
|
||||
"""
|
||||
|
||||
|
||||
class ConvBase(Module):
|
||||
"""
|
||||
Base class for ConvNets.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ConvBase, self).__init__()
|
||||
|
||||
# dirty hack - re-implement to pass the buck onto subclasses from ABC parent
|
||||
def output_shape(self, input_shape):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.nets(inputs)
|
||||
if list(self.output_shape(list(inputs.shape)[1:])) != list(
|
||||
x.shape)[1:]:
|
||||
raise ValueError('Size mismatch: expect size %s, but got size %s' %
|
||||
(str(self.output_shape(list(
|
||||
inputs.shape)[1:])), str(list(x.shape)[1:])))
|
||||
return x
|
||||
|
||||
|
||||
class SpatialSoftmax(ConvBase):
|
||||
"""
|
||||
Spatial Softmax Layer.
|
||||
|
||||
Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
|
||||
https://rll.berkeley.edu/dsae/dsae.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shape,
|
||||
num_kp=32,
|
||||
temperature=1.,
|
||||
learnable_temperature=False,
|
||||
output_variance=False,
|
||||
noise_std=0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): shape of the input feature (C, H, W)
|
||||
num_kp (int): number of keypoints (None for not using spatialsoftmax)
|
||||
temperature (float): temperature term for the softmax.
|
||||
learnable_temperature (bool): whether to learn the temperature
|
||||
output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
|
||||
noise_std (float): add random spatial noise to the predicted keypoints
|
||||
"""
|
||||
super(SpatialSoftmax, self).__init__()
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._num_kp = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._num_kp = self._in_c
|
||||
self.learnable_temperature = learnable_temperature
|
||||
self.output_variance = output_variance
|
||||
self.noise_std = noise_std
|
||||
|
||||
if self.learnable_temperature:
|
||||
# temperature will be learned
|
||||
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
|
||||
requires_grad=True)
|
||||
self.register_parameter('temperature', temperature)
|
||||
else:
|
||||
# temperature held constant after initialization
|
||||
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
|
||||
requires_grad=False)
|
||||
self.register_buffer('temperature', temperature)
|
||||
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
|
||||
np.linspace(-1., 1., self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
|
||||
self._in_w)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
|
||||
self._in_w)).float()
|
||||
self.register_buffer('pos_x', pos_x)
|
||||
self.register_buffer('pos_y', pos_y)
|
||||
|
||||
self.kps = None
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty print network."""
|
||||
header = format(str(self.__class__.__name__))
|
||||
return header + '(num_kp={}, temperature={}, noise={})'.format(
|
||||
self._num_kp, self.temperature.item(), self.noise_std)
|
||||
|
||||
def output_shape(self, input_shape):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
assert (len(input_shape) == 3)
|
||||
assert (input_shape[0] == self._in_c)
|
||||
return [self._num_kp, 2]
|
||||
|
||||
def forward(self, feature):
|
||||
"""
|
||||
Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
|
||||
probability distribution is created using a softmax, where the support is the
|
||||
pixel locations. This distribution is used to compute the expected value of
|
||||
the pixel location, which becomes a keypoint of dimension 2. K such keypoints
|
||||
are created.
|
||||
|
||||
Returns:
|
||||
out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
|
||||
keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
|
||||
under the 2D spatial softmax distribution
|
||||
"""
|
||||
assert (feature.shape[1] == self._in_c)
|
||||
assert (feature.shape[2] == self._in_h)
|
||||
assert (feature.shape[3] == self._in_w)
|
||||
if self.nets is not None:
|
||||
feature = self.nets(feature)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
feature = feature.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(feature / self.temperature, dim=-1)
|
||||
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
|
||||
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
|
||||
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
|
||||
# stack to [B * K, 2]
|
||||
expected_xy = torch.cat([expected_x, expected_y], 1)
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
|
||||
|
||||
if self.training:
|
||||
noise = torch.randn_like(feature_keypoints) * self.noise_std
|
||||
feature_keypoints += noise
|
||||
|
||||
if self.output_variance:
|
||||
# treat attention as a distribution, and compute second-order statistics to return
|
||||
expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
|
||||
dim=1,
|
||||
keepdim=True)
|
||||
expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
|
||||
dim=1,
|
||||
keepdim=True)
|
||||
expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
|
||||
dim=1,
|
||||
keepdim=True)
|
||||
var_x = expected_xx - expected_x * expected_x
|
||||
var_y = expected_yy - expected_y * expected_y
|
||||
var_xy = expected_xy - expected_x * expected_y
|
||||
# stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
|
||||
feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
|
||||
1).reshape(-1, self._num_kp, 2, 2)
|
||||
feature_keypoints = (feature_keypoints, feature_covar)
|
||||
|
||||
if isinstance(feature_keypoints, tuple):
|
||||
self.kps = (feature_keypoints[0].detach(),
|
||||
feature_keypoints[1].detach())
|
||||
else:
|
||||
self.kps = feature_keypoints.detach()
|
||||
return feature_keypoints
|
||||
83
src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
|
||||
TYPE_TO_SCHEDULER_FUNCTION)
|
||||
|
||||
|
||||
def get_scheduler(name: Union[str, SchedulerType],
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
num_training_steps: Optional[int] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Added kwargs vs diffuser's original implementation
|
||||
|
||||
Unified API to get any scheduler from its name.
|
||||
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer, **kwargs)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(
|
||||
f"{name} requires `num_warmup_steps`, please provide that argument."
|
||||
)
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
**kwargs)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(
|
||||
f"{name} requires `num_training_steps`, please provide that argument."
|
||||
)
|
||||
|
||||
return schedule_func(optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
**kwargs)
|
||||
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import pytorch_lightning as pl
|
||||
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType
|
||||
|
||||
|
||||
class SelectiveLRScheduler(_LRScheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
base_scheduler,
|
||||
group_indices,
|
||||
default_lr=[1e-5, 1e-4],
|
||||
last_epoch=-1):
|
||||
self.base_scheduler = base_scheduler
|
||||
self.group_indices = group_indices # Indices of parameter groups to update
|
||||
self.default_lr = default_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def step(self, epoch=None):
|
||||
self.base_scheduler.step()
|
||||
base_lrs = self.base_scheduler.get_last_lr()
|
||||
|
||||
for idx, group in enumerate(self.optimizer.param_groups):
|
||||
if idx in self.group_indices:
|
||||
group['lr'] = base_lrs[idx]
|
||||
else:
|
||||
# Reset the learning rate to its initial value
|
||||
group['lr'] = self.default_lr[idx]
|
||||
@@ -0,0 +1,16 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._dummy_variable = nn.Parameter()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
91
src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import collections
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Callable, List
|
||||
|
||||
|
||||
def dict_apply(
|
||||
x: Dict[str, torch.Tensor],
|
||||
func: Callable[[torch.Tensor],
|
||||
torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
result = dict()
|
||||
for key, value in x.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = dict_apply(value, func)
|
||||
else:
|
||||
result[key] = func(value)
|
||||
return result
|
||||
|
||||
|
||||
def pad_remaining_dims(x, target):
|
||||
assert x.shape == target.shape[:len(x.shape)]
|
||||
return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape)))
|
||||
|
||||
|
||||
def dict_apply_split(
|
||||
x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
|
||||
Dict[str, torch.Tensor]]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
results = collections.defaultdict(dict)
|
||||
for key, value in x.items():
|
||||
result = split_func(value)
|
||||
for k, v in result.items():
|
||||
results[k][key] = v
|
||||
return results
|
||||
|
||||
|
||||
def dict_apply_reduce(
|
||||
x: List[Dict[str,
|
||||
torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
|
||||
torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
result = dict()
|
||||
for key in x[0].keys():
|
||||
result[key] = reduce_func([x_[key] for x_ in x])
|
||||
return result
|
||||
|
||||
|
||||
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
|
||||
bool],
|
||||
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
|
||||
"""
|
||||
predicate: Return true if the module is to be replaced.
|
||||
func: Return new module to use.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
bn_list = [
|
||||
k.split('.')
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
for *parent, k in bn_list:
|
||||
parent_module = root_module
|
||||
if len(parent) > 0:
|
||||
parent_module = root_module.get_submodule('.'.join(parent))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
bn_list = [
|
||||
k.split('.')
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
assert len(bn_list) == 0
|
||||
return root_module
|
||||
|
||||
|
||||
def optimizer_to(optimizer, device):
|
||||
for state in optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(device=device)
|
||||
return optimizer
|
||||
960
src/unifolm_wma/models/diffusion_head/common/tensor_util.py
Normal file
@@ -0,0 +1,960 @@
|
||||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
import collections
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
|
||||
"""
|
||||
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
|
||||
{data_type: function_to_apply}.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
type_func_dict (dict): a mapping from data types to the functions to be
|
||||
applied for each data type.
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
assert (list not in type_func_dict)
|
||||
assert (tuple not in type_func_dict)
|
||||
assert (dict not in type_func_dict)
|
||||
|
||||
if isinstance(x, (dict, collections.OrderedDict)):
|
||||
new_x = collections.OrderedDict() if isinstance(
|
||||
x, collections.OrderedDict) else dict()
|
||||
for k, v in x.items():
|
||||
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
|
||||
return new_x
|
||||
elif isinstance(x, (list, tuple)):
|
||||
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
|
||||
if isinstance(x, tuple):
|
||||
ret = tuple(ret)
|
||||
return ret
|
||||
else:
|
||||
for t, f in type_func_dict.items():
|
||||
if isinstance(x, t):
|
||||
return f(x)
|
||||
else:
|
||||
raise NotImplementedError('Cannot handle data type %s' %
|
||||
str(type(x)))
|
||||
|
||||
|
||||
def map_tensor(x, func):
|
||||
"""
|
||||
Apply function @func to torch.Tensor objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each tensor
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(x, {
|
||||
torch.Tensor: func,
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def map_ndarray(x, func):
|
||||
"""
|
||||
Apply function @func to np.ndarray objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(x, {
|
||||
np.ndarray: func,
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
|
||||
"""
|
||||
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
|
||||
np.ndarray objects in a nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
tensor_func (function): function to apply to each tensor
|
||||
ndarray_Func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: tensor_func,
|
||||
np.ndarray: ndarray_func,
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def clone(x):
|
||||
"""
|
||||
Clones all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.clone(),
|
||||
np.ndarray: lambda x: x.copy(),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def detach(x):
|
||||
"""
|
||||
Detaches all torch tensors in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(x, {
|
||||
torch.Tensor: lambda x: x.detach(),
|
||||
})
|
||||
|
||||
|
||||
def to_batch(x):
|
||||
"""
|
||||
Introduces a leading batch dimension of 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x[None, ...],
|
||||
np.ndarray: lambda x: x[None, ...],
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_sequence(x):
|
||||
"""
|
||||
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x[:, None, ...],
|
||||
np.ndarray: lambda x: x[:, None, ...],
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def index_at_time(x, ind):
|
||||
"""
|
||||
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
|
||||
nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
ind (int): index
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x[:, ind, ...],
|
||||
np.ndarray: lambda x: x[:, ind, ...],
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
"""
|
||||
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
|
||||
in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
dim (int): dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
|
||||
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def contiguous(x):
|
||||
"""
|
||||
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
|
||||
list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.contiguous(),
|
||||
np.ndarray: lambda x: np.ascontiguousarray(x),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_device(x, device):
|
||||
"""
|
||||
Sends all torch tensors in nested dictionary or list or tuple to device
|
||||
@device, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x, d=device: x.to(d),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_tensor(x):
|
||||
"""
|
||||
Converts all numpy arrays in nested dictionary or list or tuple to
|
||||
torch tensors (and leaves existing torch Tensors as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x,
|
||||
np.ndarray: lambda x: torch.from_numpy(x),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_numpy(x):
|
||||
"""
|
||||
Converts all torch tensors in nested dictionary or list or tuple to
|
||||
numpy (and leaves existing numpy arrays as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy()
|
||||
else:
|
||||
return tensor.detach().numpy()
|
||||
|
||||
return recursive_dict_list_tuple_apply(x, {
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x,
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_list(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to a list, and returns a new nested structure. Useful for
|
||||
json encoding.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy().tolist()
|
||||
else:
|
||||
return tensor.detach().numpy().tolist()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x.tolist(),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_float(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to float type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.float(),
|
||||
np.ndarray: lambda x: x.astype(np.float32),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_uint8(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to uint8 type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.byte(),
|
||||
np.ndarray: lambda x: x.astype(np.uint8),
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def to_torch(x, device):
|
||||
"""
|
||||
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
|
||||
torch tensors on device @device and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return to_device(to_float(to_tensor(x)), device)
|
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
|
||||
"""
|
||||
Convert tensor to one-hot representation, assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): tensor containing integer labels
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): tensor containing one-hot representation of labels
|
||||
"""
|
||||
x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device)
|
||||
x.scatter_(-1, tensor.unsqueeze(-1), 1)
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
|
||||
"""
|
||||
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
|
||||
assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(tensor,
|
||||
func=lambda x, nc=num_class: to_one_hot_single(x, nc))
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
|
||||
"""
|
||||
Flatten a tensor in all dimensions from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to flatten
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): flattened tensor
|
||||
"""
|
||||
fixed_size = x.size()[:begin_axis]
|
||||
_s = list(fixed_size) + [-1]
|
||||
return x.reshape(*_s)
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
|
||||
"""
|
||||
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(x, {
|
||||
torch.Tensor:
|
||||
lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
|
||||
})
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions in a tensor to a target dimension.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to reshape
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reshaped tensor
|
||||
"""
|
||||
assert (begin_axis <= end_axis)
|
||||
assert (begin_axis >= 0)
|
||||
assert (end_axis < len(x.shape))
|
||||
assert (isinstance(target_dims, (tuple, list)))
|
||||
s = x.shape
|
||||
final_s = []
|
||||
for i in range(len(s)):
|
||||
if i == begin_axis:
|
||||
final_s.extend(target_dims)
|
||||
elif i < begin_axis or i > end_axis:
|
||||
final_s.append(s[i])
|
||||
return x.reshape(*final_s)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
|
||||
to a target dimension.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor:
|
||||
lambda x, b=begin_axis, e=end_axis, t=target_dims:
|
||||
reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t),
|
||||
np.ndarray:
|
||||
lambda x, b=begin_axis, e=end_axis, t=target_dims:
|
||||
reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t),
|
||||
type(None):
|
||||
lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
|
||||
"""
|
||||
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
|
||||
all tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor:
|
||||
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]),
|
||||
np.ndarray:
|
||||
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]),
|
||||
type(None):
|
||||
lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
|
||||
"""
|
||||
Expand a tensor at a single dimension @dim by @size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input tensor
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): expanded tensor
|
||||
"""
|
||||
assert dim < x.ndimension()
|
||||
assert x.shape[dim] == 1
|
||||
expand_dims = [-1] * x.ndimension()
|
||||
expand_dims[dim] = size
|
||||
return x.expand(*expand_dims)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
|
||||
"""
|
||||
Expand all tensors in nested dictionary or list or tuple at a single
|
||||
dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
|
||||
"""
|
||||
Unsqueeze and expand a tensor at a dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to unsqueeze and expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze(x, dim)
|
||||
return expand_at(x, size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
|
||||
"""
|
||||
Repeat a dimension by combining expand and reshape operations.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
repeats (int): number of times to repeat the target dimension
|
||||
dim (int): dimension to repeat on
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze_expand_at(x, repeats, dim + 1)
|
||||
return join_dimensions(x, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
|
||||
"""
|
||||
Reduce tensor at a dimension by named reduction functions.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to be reduced
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reduced tensor
|
||||
"""
|
||||
assert x.ndimension() > dim
|
||||
assert reduction in ["sum", "max", "mean", "flatten"]
|
||||
if reduction == "flatten":
|
||||
x = flatten(x, begin_axis=dim)
|
||||
elif reduction == "max":
|
||||
x = torch.max(x, dim=dim)[0] # [B, D]
|
||||
elif reduction == "sum":
|
||||
x = torch.sum(x, dim=dim)
|
||||
else:
|
||||
x = torch.mean(x, dim=dim)
|
||||
return x
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
|
||||
"""
|
||||
Reduces all tensors in nested dictionary or list or tuple at a dimension
|
||||
using a named reduction function.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(
|
||||
x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
This function indexes out a target dimension of a tensor in a structured way,
|
||||
by allowing a different value to be selected for each member of a flat index
|
||||
tensor (@indices) corresponding to a source dimension. This can be interpreted
|
||||
as moving along the source dimension, using the corresponding index value
|
||||
in @indices to select values for all other dimensions outside of the
|
||||
source and target dimensions. A common use case is to gather values
|
||||
in target dimension 1 for each batch member (target dimension 0).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to gather values for
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
|
||||
"""
|
||||
assert len(indices.shape) == 1
|
||||
assert x.shape[source_dim] == indices.shape[0]
|
||||
|
||||
# unsqueeze in all dimensions except the source dimension
|
||||
new_shape = [1] * x.ndimension()
|
||||
new_shape[source_dim] = -1
|
||||
indices = indices.reshape(*new_shape)
|
||||
|
||||
# repeat in all dimensions - but preserve shape of source dimension,
|
||||
# and make sure target_dimension has singleton dimension
|
||||
expand_shape = list(x.shape)
|
||||
expand_shape[source_dim] = -1
|
||||
expand_shape[target_dim] = 1
|
||||
indices = indices.expand(*expand_shape)
|
||||
|
||||
out = x.gather(dim=target_dim, index=indices)
|
||||
return out.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
Apply @gather_along_dim_with_dim_single to all tensors in a nested
|
||||
dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x,
|
||||
lambda y, t=target_dim, s=source_dim, i=indices:
|
||||
gather_along_dim_with_dim_single(y, t, s, i))
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
|
||||
"""
|
||||
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
|
||||
the batch given an index for each sequence.
|
||||
|
||||
Args:
|
||||
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Return:
|
||||
y (torch.Tensor): indexed tensor of shape [B, ....]
|
||||
"""
|
||||
return gather_along_dim_with_dim_single(seq,
|
||||
target_dim=1,
|
||||
source_dim=0,
|
||||
indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
|
||||
"""
|
||||
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
|
||||
for tensors with leading dimensions [B, T, ...].
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
|
||||
"""
|
||||
return gather_along_dim_with_dim(seq,
|
||||
target_dim=1,
|
||||
source_dim=0,
|
||||
indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq,
|
||||
padding,
|
||||
batched=False,
|
||||
pad_same=True,
|
||||
pad_values=None):
|
||||
"""
|
||||
Pad input tensor or array @seq in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (np.ndarray or torch.Tensor): sequence to be padded
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (np.ndarray or torch.Tensor)
|
||||
"""
|
||||
assert isinstance(seq, (np.ndarray, torch.Tensor))
|
||||
assert pad_same or pad_values is not None
|
||||
if pad_values is not None:
|
||||
assert isinstance(pad_values, float)
|
||||
repeat_func = np.repeat if isinstance(
|
||||
seq, np.ndarray) else torch.repeat_interleave
|
||||
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
|
||||
ones_like_func = np.ones_like if isinstance(
|
||||
seq, np.ndarray) else torch.ones_like
|
||||
seq_dim = 1 if batched else 0
|
||||
|
||||
begin_pad = []
|
||||
end_pad = []
|
||||
|
||||
if padding[0] > 0:
|
||||
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
|
||||
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
|
||||
if padding[1] > 0:
|
||||
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
|
||||
end_pad.append(repeat_func(pad, padding[1], seq_dim))
|
||||
|
||||
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
|
||||
|
||||
|
||||
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
|
||||
"""
|
||||
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (dict or list or tuple)
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
seq, {
|
||||
torch.Tensor:
|
||||
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
|
||||
pad_sequence_single(x, p, b, ps, pv),
|
||||
np.ndarray:
|
||||
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
|
||||
pad_sequence_single(x, p, b, ps, pv),
|
||||
type(None):
|
||||
lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def assert_size_at_dim_single(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that array or tensor @x has size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (np.ndarray or torch.Tensor): input array or tensor
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
msg (str): text to display if assertion fails
|
||||
"""
|
||||
assert x.shape[dim] == size, msg
|
||||
|
||||
|
||||
def assert_size_at_dim(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that arrays and tensors in nested dictionary or list or tuple have
|
||||
size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
"""
|
||||
map_tensor(
|
||||
x,
|
||||
lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
|
||||
|
||||
|
||||
def get_shape(x):
|
||||
"""
|
||||
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
|
||||
tensor's shape
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x, {
|
||||
torch.Tensor: lambda x: x.shape,
|
||||
np.ndarray: lambda x: x.shape,
|
||||
type(None): lambda x: x,
|
||||
})
|
||||
|
||||
|
||||
def list_of_flat_dict_to_dict_of_list(list_of_dict):
|
||||
"""
|
||||
Helper function to go from a list of flat dictionaries to a dictionary of lists.
|
||||
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
|
||||
floats, etc.
|
||||
|
||||
Args:
|
||||
list_of_dict (list): list of flat dictionaries
|
||||
|
||||
Returns:
|
||||
dict_of_list (dict): dictionary of lists
|
||||
"""
|
||||
assert isinstance(list_of_dict, list)
|
||||
dic = collections.OrderedDict()
|
||||
for i in range(len(list_of_dict)):
|
||||
for k in list_of_dict[i]:
|
||||
if k not in dic:
|
||||
dic[k] = []
|
||||
dic[k].append(list_of_dict[i][k])
|
||||
return dic
|
||||
|
||||
|
||||
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
|
||||
"""
|
||||
Flatten a nested dict or list to a list.
|
||||
|
||||
For example, given a dict
|
||||
{
|
||||
a: 1
|
||||
b: {
|
||||
c: 2
|
||||
}
|
||||
c: 3
|
||||
}
|
||||
|
||||
the function would return [(a, 1), (b_c, 2), (c, 3)]
|
||||
|
||||
Args:
|
||||
d (dict, list): a nested dict or list to be flattened
|
||||
parent_key (str): recursion helper
|
||||
sep (str): separator for nesting keys
|
||||
item_key (str): recursion helper
|
||||
Returns:
|
||||
list: a list of (key, value) tuples
|
||||
"""
|
||||
items = []
|
||||
if isinstance(d, (tuple, list)):
|
||||
new_key = parent_key + sep + item_key if len(
|
||||
parent_key) > 0 else item_key
|
||||
for i, v in enumerate(d):
|
||||
items.extend(
|
||||
flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
|
||||
return items
|
||||
elif isinstance(d, dict):
|
||||
new_key = parent_key + sep + item_key if len(
|
||||
parent_key) > 0 else item_key
|
||||
for k, v in d.items():
|
||||
assert isinstance(k, str)
|
||||
items.extend(
|
||||
flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
|
||||
return items
|
||||
else:
|
||||
new_key = parent_key + sep + item_key if len(
|
||||
parent_key) > 0 else item_key
|
||||
return [(new_key, d)]
|
||||
|
||||
|
||||
def time_distributed(inputs,
|
||||
op,
|
||||
activation=None,
|
||||
inputs_as_kwargs=False,
|
||||
inputs_as_args=False,
|
||||
**kwargs):
|
||||
"""
|
||||
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
|
||||
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
|
||||
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
|
||||
outputs to [B, T, ...].
|
||||
|
||||
Args:
|
||||
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
op: a layer op that accepts inputs
|
||||
activation: activation to apply at the output
|
||||
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
|
||||
inputs_as_args (bool) whether to feed input as a args list to the op
|
||||
kwargs (dict): other kwargs to supply to the op
|
||||
|
||||
Returns:
|
||||
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
|
||||
"""
|
||||
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
|
||||
inputs = join_dimensions(inputs, 0, 1)
|
||||
if inputs_as_kwargs:
|
||||
outputs = op(**inputs, **kwargs)
|
||||
elif inputs_as_args:
|
||||
outputs = op(*inputs, **kwargs)
|
||||
else:
|
||||
outputs = op(inputs, **kwargs)
|
||||
|
||||
if activation is not None:
|
||||
outputs = map_tensor(outputs, activation)
|
||||
outputs = reshape_dimensions(outputs,
|
||||
begin_axis=0,
|
||||
end_axis=0,
|
||||
target_dims=(batch_size, seq_len))
|
||||
return outputs
|
||||
701
src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from typing import Union
|
||||
|
||||
from unifolm_wma.models.diffusion_head.conv1d_components import (
|
||||
Downsample1d, Upsample1d, Conv1dBlock)
|
||||
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(
|
||||
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.,
|
||||
relative_position=False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout))
|
||||
|
||||
def efficient_forward(self, x, context=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
|
||||
q = self.to_q(x)
|
||||
if spatial_self_attn:
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q,
|
||||
k,
|
||||
v,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out = (out.unsqueeze(0).reshape(
|
||||
b, self.heads, out.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
attention_cls=None):
|
||||
super().__init__()
|
||||
attn_cls = CrossAttention if attention_cls is None else attention_cls
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None)
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout)
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None, **kwargs):
|
||||
## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
|
||||
input_tuple = (
|
||||
x,
|
||||
) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
|
||||
if context is not None:
|
||||
input_tuple = (x, context)
|
||||
return checkpoint(self._forward, input_tuple, self.parameters(),
|
||||
self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, mask=None):
|
||||
x = self.attn1(self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
mask=mask) + x
|
||||
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class ActionLatentImageCrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
in_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
use_checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
use_linear=True):
|
||||
super().__init__()
|
||||
"""
|
||||
in_channels: action input dim
|
||||
|
||||
"""
|
||||
self.in_channels = in_channels
|
||||
self.in_dim = in_dim
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=8,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
|
||||
self.proj_in_action = nn.Linear(in_dim, inner_dim)
|
||||
self.proj_in_cond = nn.Linear(context_dim, inner_dim)
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
|
||||
self.use_linear = use_linear
|
||||
|
||||
attention_cls = None
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
checkpoint=use_checkpoint,
|
||||
attention_cls=attention_cls)
|
||||
for d in range(depth)
|
||||
])
|
||||
|
||||
def forward(self, x, context=None, **kwargs):
|
||||
ba, ca, da = x.shape
|
||||
b, t, c, h, w = context.shape
|
||||
context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
|
||||
|
||||
x_in = x
|
||||
x = self.norm(x) # ba x ja x d_in
|
||||
if self.use_linear:
|
||||
x = self.proj_in_action(x)
|
||||
context = self.proj_in_cond(context)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, context=context, **kwargs)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
cond_dim,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=True,
|
||||
use_linear_act_proj=False):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Conv1dBlock(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
n_groups=n_groups),
|
||||
Conv1dBlock(out_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
n_groups=n_groups),
|
||||
])
|
||||
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.use_linear_act_proj = use_linear_act_proj
|
||||
self.out_channels = out_channels
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale and use_linear_act_proj:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
)
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
|
||||
if in_channels != out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
'''
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
'''
|
||||
B, T, _ = cond.shape
|
||||
|
||||
out = self.blocks[0](x)
|
||||
if self.cond_predict_scale:
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.use_linear_act_proj:
|
||||
embed = embed.reshape(B * T, -1)
|
||||
embed = embed.reshape(-1, 2, self.out_channels, 1)
|
||||
else:
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
# else:
|
||||
# out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
n_obs_steps=1,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=[256, 512, 1024],
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False,
|
||||
horizon=16,
|
||||
num_head_channels=64,
|
||||
use_linear_attn=True,
|
||||
use_linear_act_proj=True,
|
||||
act_proj_dim=32,
|
||||
cond_cross_attention=False,
|
||||
context_dims=None,
|
||||
image_size=None,
|
||||
imagen_cond_gradient=False,
|
||||
last_frame_only=False,
|
||||
use_imagen_mid_only=False,
|
||||
use_z_only=False,
|
||||
spatial_num_kp=32,
|
||||
obs_encoder_config=None):
|
||||
super().__init__()
|
||||
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.obs_encoder = instantiate_from_config(obs_encoder_config)
|
||||
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||
local_cond_encoder = None
|
||||
down_modules = nn.ModuleList([])
|
||||
|
||||
dim_a_list = []
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
if ind == 0:
|
||||
dim_a = horizon
|
||||
else:
|
||||
dim_a = horizon // 2 * ind
|
||||
dim_a_list.append(dim_a)
|
||||
|
||||
# for attention
|
||||
num_heads = dim_out // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if use_linear_act_proj:
|
||||
if use_imagen_mid_only:
|
||||
cur_cond_dim = cond_dim + 2 * context_dims[-1]
|
||||
elif use_z_only:
|
||||
cur_cond_dim = cond_dim + 2 * spatial_num_kp
|
||||
else:
|
||||
cur_cond_dim = cond_dim + 2 * context_dims[ind]
|
||||
else:
|
||||
cur_cond_dim = cond_dim + horizon * context_dims[ind]
|
||||
|
||||
down_modules.append(
|
||||
nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ActionLatentImageCrossAttention(
|
||||
dim_out,
|
||||
dim_a,
|
||||
num_heads,
|
||||
dim_head,
|
||||
context_dim=context_dims[ind],
|
||||
use_linear=use_linear_attn)
|
||||
if cond_cross_attention else nn.Identity(),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ActionLatentImageCrossAttention(mid_dim,
|
||||
dim_a_list[-1],
|
||||
num_heads,
|
||||
dim_head,
|
||||
context_dim=context_dims[-1],
|
||||
use_linear=use_linear_attn)
|
||||
if cond_cross_attention else nn.Identity(),
|
||||
])
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
context_dims = context_dims[::-1]
|
||||
for ind, (dim_in, dim_out) in enumerate(
|
||||
reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
if use_linear_act_proj:
|
||||
if use_imagen_mid_only:
|
||||
cur_cond_dim = cond_dim + 2 * context_dims[0]
|
||||
elif use_z_only:
|
||||
cur_cond_dim = cond_dim + 2 * spatial_num_kp
|
||||
else:
|
||||
cur_cond_dim = cond_dim + 2 * context_dims[ind]
|
||||
else:
|
||||
cur_cond_dim = cond_dim + horizon * context_dims[ind]
|
||||
up_modules.append(
|
||||
nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out + dim_in,
|
||||
dim_in,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cur_cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
use_linear_act_proj=use_linear_act_proj),
|
||||
ActionLatentImageCrossAttention(
|
||||
dim_in,
|
||||
dim_a_list.pop(),
|
||||
num_heads,
|
||||
dim_head,
|
||||
context_dim=context_dims[ind],
|
||||
use_linear=use_linear_attn)
|
||||
if cond_cross_attention else nn.Identity(),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
if use_z_only:
|
||||
h, w = image_size
|
||||
self.spatial_softmax_blocks = nn.ModuleList(
|
||||
[SpatialSoftmax((4, h, w), spatial_num_kp)])
|
||||
else:
|
||||
self.spatial_softmax_blocks = nn.ModuleList([])
|
||||
context_dims = context_dims[::-1]
|
||||
for ind, context_dim in enumerate(context_dims):
|
||||
h, w = image_size
|
||||
if ind != 0:
|
||||
h //= 2**ind
|
||||
w //= 2**ind
|
||||
net = SpatialSoftmax((context_dim, h, w), context_dim)
|
||||
self.spatial_softmax_blocks.append(net)
|
||||
self.spatial_softmax_blocks.append(net)
|
||||
self.spatial_softmax_blocks += self.spatial_softmax_blocks[
|
||||
0:4][::-1]
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
self.cond_cross_attention = cond_cross_attention
|
||||
self.use_linear_act_proj = use_linear_act_proj
|
||||
|
||||
self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
|
||||
nn.LayerNorm(act_proj_dim))
|
||||
self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
|
||||
nn.LayerNorm(act_proj_dim))
|
||||
self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
|
||||
nn.Linear(act_proj_dim, 1))
|
||||
self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
|
||||
nn.Linear(act_proj_dim, horizon))
|
||||
logger.info("number of parameters: %e",
|
||||
sum(p.numel() for p in self.parameters()))
|
||||
|
||||
self.imagen_cond_gradient = imagen_cond_gradient
|
||||
self.use_imagen_mid_only = use_imagen_mid_only
|
||||
self.use_z_only = use_z_only
|
||||
self.spatial_num_kp = spatial_num_kp
|
||||
self.last_frame_only = last_frame_only
|
||||
self.horizon = horizon
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
imagen_cond=None,
|
||||
cond=None,
|
||||
**kwargs):
|
||||
"""
|
||||
sample: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
imagen_cond: a list of hidden info from video gen unet
|
||||
cond: dict:
|
||||
image: (B, 3, To, h, w)
|
||||
agent_pos: (B, Ta, d)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
|
||||
if not self.imagen_cond_gradient:
|
||||
imagen_cond = [c.detach() for c in imagen_cond]
|
||||
|
||||
cond = {'image': cond[0], 'agent_pos': cond[1]}
|
||||
|
||||
cond['image'] = cond['image'].permute(0, 2, 1, 3,
|
||||
4)
|
||||
cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
|
||||
cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
|
||||
|
||||
B, T, D = sample.shape
|
||||
if self.use_linear_act_proj:
|
||||
sample = self.proj_in_action(sample.unsqueeze(-1))
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
else:
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
sample = self.proj_in_horizon(sample)
|
||||
robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
|
||||
robo_state_cond = repeat(robo_state_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=2)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps],
|
||||
dtype=torch.long,
|
||||
device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
||||
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
||||
|
||||
x = sample if not self.use_linear_act_proj else sample.reshape(
|
||||
B * T, D, -1)
|
||||
h = []
|
||||
for idx, modules in enumerate(self.down_modules):
|
||||
if self.cond_cross_attention:
|
||||
(resnet, resnet2, crossatten, downsample) = modules
|
||||
else:
|
||||
(resnet, resnet2, _, downsample) = modules
|
||||
|
||||
# Access the cond from the unet embeds from video unet
|
||||
if self.use_imagen_mid_only:
|
||||
imagen_cond = imagen_cond_mid
|
||||
elif self.use_z_only:
|
||||
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
imagen_cond = imagen_cond_down[idx]
|
||||
if self.last_frame_only:
|
||||
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
|
||||
imagen_cond = repeat(imagen_cond,
|
||||
'b t c h w -> b (repeat t) c h w',
|
||||
repeat=self.horizon)
|
||||
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
|
||||
if self.use_imagen_mid_only:
|
||||
imagen_cond = self.spatial_softmax_blocks[len(
|
||||
self.spatial_softmax_blocks) // 2](imagen_cond)
|
||||
elif self.use_z_only:
|
||||
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
|
||||
else:
|
||||
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
|
||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||
|
||||
if self.use_linear_act_proj:
|
||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(
|
||||
1).repeat_interleave(repeats=T, dim=1)
|
||||
else:
|
||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(
|
||||
1).repeat_interleave(repeats=2, dim=1)
|
||||
cur_global_feature = torch.cat(
|
||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
||||
x = resnet(x, cur_global_feature)
|
||||
x = resnet2(x, cur_global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
#>>> mide blocks
|
||||
resnet, resnet2, _ = self.mid_modules
|
||||
# Access the cond from the unet embeds from video unet
|
||||
if self.use_z_only:
|
||||
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
imagen_cond = imagen_cond_mid
|
||||
if self.last_frame_only:
|
||||
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
|
||||
imagen_cond = repeat(imagen_cond,
|
||||
'b t c h w -> b (repeat t) c h w',
|
||||
repeat=self.horizon)
|
||||
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
|
||||
idx += 1
|
||||
if self.use_z_only:
|
||||
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
|
||||
else:
|
||||
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
|
||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||
if self.use_linear_act_proj:
|
||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
||||
repeats=T, dim=1)
|
||||
else:
|
||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
||||
repeats=2, dim=1)
|
||||
cur_global_feature = torch.cat(
|
||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
||||
x = resnet(x, cur_global_feature)
|
||||
x = resnet2(x, cur_global_feature)
|
||||
|
||||
#>>> up blocks
|
||||
idx += 1
|
||||
for jdx, modules in enumerate(self.up_modules):
|
||||
if self.cond_cross_attention:
|
||||
(resnet, resnet2, crossatten, upsample) = modules
|
||||
else:
|
||||
(resnet, resnet2, _, upsample) = modules
|
||||
|
||||
# Access the cond from the unet embeds from video unet
|
||||
if self.use_imagen_mid_only:
|
||||
imagen_cond = imagen_cond_mid
|
||||
elif self.use_z_only:
|
||||
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
imagen_cond = imagen_cond_up[jdx]
|
||||
if self.last_frame_only:
|
||||
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
|
||||
imagen_cond = repeat(imagen_cond,
|
||||
'b t c h w -> b (repeat t) c h w',
|
||||
repeat=self.horizon)
|
||||
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
|
||||
if self.use_imagen_mid_only:
|
||||
imagen_cond = self.spatial_softmax_blocks[len(
|
||||
self.spatial_softmax_blocks) // 2](imagen_cond)
|
||||
elif self.use_z_only:
|
||||
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
|
||||
else:
|
||||
imagen_cond = self.spatial_softmax_blocks[jdx +
|
||||
idx](imagen_cond)
|
||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||
|
||||
if self.use_linear_act_proj:
|
||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(
|
||||
1).repeat_interleave(repeats=T, dim=1)
|
||||
else:
|
||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||
cur_global_feature = global_feature.unsqueeze(
|
||||
1).repeat_interleave(repeats=2, dim=1)
|
||||
|
||||
cur_global_feature = torch.cat(
|
||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
||||
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, cur_global_feature)
|
||||
x = resnet2(x, cur_global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
if self.use_linear_act_proj:
|
||||
x = x.reshape(B, T, D, -1)
|
||||
x = self.proj_out_action(x)
|
||||
x = x.reshape(B, T, D)
|
||||
else:
|
||||
x = self.proj_out_horizon(x)
|
||||
x = einops.rearrange(x, 'b t h -> b h t')
|
||||
return x
|
||||
52
src/unifolm_wma/models/diffusion_head/conv1d_components.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
'''
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
'''
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
def test():
|
||||
cb = Conv1dBlock(256, 128, kernel_size=3)
|
||||
x = torch.zeros((1, 256, 16))
|
||||
o = cb(x)
|
||||
80
src/unifolm_wma/models/diffusion_head/ema_model.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EMAModel:
|
||||
"""
|
||||
Exponential Moving Average of models weights
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
update_after_step=0,
|
||||
inv_gamma=1.0,
|
||||
power=2 / 3,
|
||||
min_value=0.0,
|
||||
max_value=0.9999):
|
||||
"""
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||
at 215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
|
||||
self.averaged_model = model
|
||||
self.averaged_model.eval()
|
||||
self.averaged_model.requires_grad_(False)
|
||||
|
||||
self.update_after_step = update_after_step
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
self.decay = 0.0
|
||||
self.optimization_step = 0
|
||||
|
||||
def get_decay(self, optimization_step):
|
||||
"""
|
||||
Compute the decay factor for the exponential moving average.
|
||||
"""
|
||||
step = max(0, optimization_step - self.update_after_step - 1)
|
||||
value = 1 - (1 + step / self.inv_gamma)**-self.power
|
||||
|
||||
if step <= 0:
|
||||
return 0.0
|
||||
|
||||
return max(self.min_value, min(value, self.max_value))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, new_model):
|
||||
self.decay = self.get_decay(self.optimization_step)
|
||||
|
||||
all_dataptrs = set()
|
||||
for module, ema_module in zip(new_model.modules(),
|
||||
self.averaged_model.modules()):
|
||||
for param, ema_param in zip(module.parameters(recurse=False),
|
||||
ema_module.parameters(recurse=False)):
|
||||
# iterative over immediate parameters only.
|
||||
if isinstance(param, dict):
|
||||
raise RuntimeError('Dict parameter not supported')
|
||||
|
||||
if isinstance(module, _BatchNorm):
|
||||
# skip batchnorms
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
elif not param.requires_grad:
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
else:
|
||||
ema_param.mul_(self.decay)
|
||||
ema_param.add_(param.data.to(dtype=ema_param.dtype),
|
||||
alpha=1 - self.decay)
|
||||
|
||||
# verify that iterating over module and then parameters is identical to parameters recursively.
|
||||
# assert old_all_dataptrs == all_dataptrs
|
||||
self.optimization_step += 1
|
||||
@@ -0,0 +1,19 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
322
src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
|
||||
"""
|
||||
Randomly sample crops at input, and then average across crop features at output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shape,
|
||||
crop_height,
|
||||
crop_width,
|
||||
num_crops=1,
|
||||
pos_enc=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_shape (tuple, list): shape of input (not including batch dimension)
|
||||
crop_height (int): crop height
|
||||
crop_width (int): crop width
|
||||
num_crops (int): number of random crops to take
|
||||
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
|
||||
location of the cropped pixels in the source image
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3 # (C, H, W)
|
||||
assert crop_height < input_shape[1]
|
||||
assert crop_width < input_shape[2]
|
||||
|
||||
self.input_shape = input_shape
|
||||
self.crop_height = crop_height
|
||||
self.crop_width = crop_width
|
||||
self.num_crops = num_crops
|
||||
self.pos_enc = pos_enc
|
||||
|
||||
def output_shape_in(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_in operation, where raw inputs (usually observation modalities)
|
||||
are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
|
||||
# the number of crops are reshaped into the batch dimension, increasing the batch
|
||||
# size from B to B * N
|
||||
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
|
||||
return [out_c, self.crop_height, self.crop_width]
|
||||
|
||||
def output_shape_out(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_out operation, where processed inputs (usually encoded observation
|
||||
modalities) are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
|
||||
# and then pools to result in [B, ...], only the batch dimension changes,
|
||||
# and so the other dimensions retain their shape.
|
||||
return list(input_shape)
|
||||
|
||||
def forward_in(self, inputs):
|
||||
"""
|
||||
Samples N random crops for each input in the batch, and then reshapes
|
||||
inputs to [B * N, ...].
|
||||
"""
|
||||
assert len(
|
||||
inputs.shape) >= 3 # must have at least (C, H, W) dimensions
|
||||
if self.training:
|
||||
# generate random crops
|
||||
out, _ = sample_random_image_crops(
|
||||
images=inputs,
|
||||
crop_height=self.crop_height,
|
||||
crop_width=self.crop_width,
|
||||
num_crops=self.num_crops,
|
||||
pos_enc=self.pos_enc,
|
||||
)
|
||||
# [B, N, ...] -> [B * N, ...]
|
||||
return tu.join_dimensions(out, 0, 1)
|
||||
else:
|
||||
# take center crop during eval
|
||||
out = ttf.center_crop(img=inputs,
|
||||
output_size=(self.crop_height,
|
||||
self.crop_width))
|
||||
if self.num_crops > 1:
|
||||
B, C, H, W = out.shape
|
||||
out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
|
||||
W).reshape(-1, C, H, W)
|
||||
# [B * N, ...]
|
||||
return out
|
||||
|
||||
def forward_out(self, inputs):
|
||||
"""
|
||||
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
|
||||
to result in shape [B, ...] to make sure the network output is consistent with
|
||||
what would have happened if there were no randomization.
|
||||
"""
|
||||
if self.num_crops <= 1:
|
||||
return inputs
|
||||
else:
|
||||
batch_size = (inputs.shape[0] // self.num_crops)
|
||||
out = tu.reshape_dimensions(inputs,
|
||||
begin_axis=0,
|
||||
end_axis=0,
|
||||
target_dims=(batch_size,
|
||||
self.num_crops))
|
||||
return out.mean(dim=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.forward_in(inputs)
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty print network."""
|
||||
header = '{}'.format(str(self.__class__.__name__))
|
||||
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
|
||||
self.input_shape, self.crop_height, self.crop_width,
|
||||
self.num_crops)
|
||||
return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
|
||||
"""
|
||||
Crops images at the locations specified by @crop_indices. Crops will be
|
||||
taken across all channels.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
|
||||
N is the number of crops to take per image and each entry corresponds
|
||||
to the pixel height and width of where to take the crop. Note that
|
||||
the indices can also be of shape [..., 2] if only 1 crop should
|
||||
be taken per image. Leading dimensions must be consistent with
|
||||
@images argument. Each index specifies the top left of the crop.
|
||||
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
|
||||
H and W are the height and width of @images and CH and CW are
|
||||
@crop_height and @crop_width.
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
Returns:
|
||||
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
|
||||
"""
|
||||
|
||||
# make sure length of input shapes is consistent
|
||||
assert crop_indices.shape[-1] == 2
|
||||
ndim_im_shape = len(images.shape)
|
||||
ndim_indices_shape = len(crop_indices.shape)
|
||||
assert (ndim_im_shape == ndim_indices_shape +
|
||||
1) or (ndim_im_shape == ndim_indices_shape + 2)
|
||||
|
||||
# maybe pad so that @crop_indices is shape [..., N, 2]
|
||||
is_padded = False
|
||||
if ndim_im_shape == ndim_indices_shape + 2:
|
||||
crop_indices = crop_indices.unsqueeze(-2)
|
||||
is_padded = True
|
||||
|
||||
# make sure leading dimensions between images and indices are consistent
|
||||
assert images.shape[:-3] == crop_indices.shape[:-2]
|
||||
|
||||
device = images.device
|
||||
image_c, image_h, image_w = images.shape[-3:]
|
||||
num_crops = crop_indices.shape[-2]
|
||||
|
||||
# make sure @crop_indices are in valid range
|
||||
assert (crop_indices[..., 0] >= 0).all().item()
|
||||
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
|
||||
assert (crop_indices[..., 1] >= 0).all().item()
|
||||
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
|
||||
|
||||
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
|
||||
|
||||
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
|
||||
crop_ind_grid_h = torch.arange(crop_height).to(device)
|
||||
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
|
||||
size=crop_width,
|
||||
dim=-1)
|
||||
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
|
||||
crop_ind_grid_w = torch.arange(crop_width).to(device)
|
||||
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
|
||||
size=crop_height,
|
||||
dim=0)
|
||||
# combine into shape [CH, CW, 2]
|
||||
crop_in_grid = torch.cat(
|
||||
(crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
|
||||
|
||||
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
|
||||
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
|
||||
# shape array that tells us which pixels from the corresponding source image to grab.
|
||||
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
|
||||
crop_height, crop_width, 2
|
||||
]
|
||||
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
|
||||
-2) + crop_in_grid.reshape(grid_reshape)
|
||||
|
||||
# For using @torch.gather, convert to flat indices from 2D indices, and also
|
||||
# repeat across the channel dimension. To get flat index of each pixel to grab for
|
||||
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
|
||||
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
|
||||
..., 1] # shape [..., N, CH, CW]
|
||||
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
|
||||
dim=-3) # shape [..., N, C, CH, CW]
|
||||
all_crop_inds = tu.flatten(all_crop_inds,
|
||||
begin_axis=-2) # shape [..., N, C, CH * CW]
|
||||
|
||||
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
|
||||
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
|
||||
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
|
||||
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
|
||||
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
|
||||
reshape_axis = len(crops.shape) - 1
|
||||
crops = tu.reshape_dimensions(crops,
|
||||
begin_axis=reshape_axis,
|
||||
end_axis=reshape_axis,
|
||||
target_dims=(crop_height, crop_width))
|
||||
|
||||
if is_padded:
|
||||
# undo padding -> [..., C, CH, CW]
|
||||
crops = crops.squeeze(-4)
|
||||
return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images,
|
||||
crop_height,
|
||||
crop_width,
|
||||
num_crops,
|
||||
pos_enc=False):
|
||||
"""
|
||||
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
|
||||
@images.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
num_crops (n): number of crops to sample
|
||||
|
||||
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
|
||||
encoding of the original source pixel locations. This means that the
|
||||
output crops will contain information about where in the source image
|
||||
it was sampled from.
|
||||
|
||||
Returns:
|
||||
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
|
||||
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
|
||||
|
||||
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
|
||||
"""
|
||||
device = images.device
|
||||
|
||||
# maybe add 2 channels of spatial encoding to the source image
|
||||
source_im = images
|
||||
if pos_enc:
|
||||
# spatial encoding [y, x] in [0, 1]
|
||||
h, w = source_im.shape[-2:]
|
||||
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
|
||||
pos_y = pos_y.float().to(device) / float(h)
|
||||
pos_x = pos_x.float().to(device) / float(w)
|
||||
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
|
||||
|
||||
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
|
||||
leading_shape = source_im.shape[:-3]
|
||||
position_enc = position_enc[(None, ) * len(leading_shape)]
|
||||
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
|
||||
|
||||
# concat across channel dimension with input
|
||||
source_im = torch.cat((source_im, position_enc), dim=-3)
|
||||
|
||||
# make sure sample boundaries ensure crops are fully within the images
|
||||
image_c, image_h, image_w = source_im.shape[-3:]
|
||||
max_sample_h = image_h - crop_height
|
||||
max_sample_w = image_w - crop_width
|
||||
|
||||
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
|
||||
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
|
||||
# we will sample [B, N] indices, but this supports having more than one leading dimension,
|
||||
# or possibly no leading dimension.
|
||||
#
|
||||
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
|
||||
crop_inds_h = (
|
||||
max_sample_h *
|
||||
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds_w = (
|
||||
max_sample_w *
|
||||
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds = torch.cat(
|
||||
(crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
|
||||
dim=-1) # shape [..., N, 2]
|
||||
|
||||
crops = crop_image_from_indices(
|
||||
images=source_im,
|
||||
crop_indices=crop_inds,
|
||||
crop_height=crop_height,
|
||||
crop_width=crop_width,
|
||||
)
|
||||
|
||||
return crops, crop_inds
|
||||
30
src/unifolm_wma/models/diffusion_head/vision/model_getter.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
|
||||
def get_resnet(name, weights=None, **kwargs):
|
||||
"""
|
||||
name: resnet18, resnet34, resnet50
|
||||
weights: "IMAGENET1K_V1", "r3m"
|
||||
"""
|
||||
# load r3m weights
|
||||
if (weights == "r3m") or (weights == "R3M"):
|
||||
return get_r3m(name=name, **kwargs)
|
||||
|
||||
func = getattr(torchvision.models, name)
|
||||
resnet = func(weights=weights, **kwargs)
|
||||
resnet.fc = torch.nn.Identity()
|
||||
return resnet
|
||||
|
||||
|
||||
def get_r3m(name, **kwargs):
|
||||
"""
|
||||
name: resnet18, resnet34, resnet50
|
||||
"""
|
||||
import r3m
|
||||
r3m.device = 'cpu'
|
||||
model = r3m.load_r3m(name)
|
||||
r3m_model = model.module
|
||||
resnet_model = r3m_model.convnet
|
||||
resnet_model = resnet_model.to('cpu')
|
||||
return resnet_model
|
||||
@@ -0,0 +1,247 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import json
|
||||
import os
|
||||
|
||||
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
|
||||
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from einops import rearrange, repeat
|
||||
from typing import Dict, Tuple, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rgb_model_config: Dict,
|
||||
shape_meta_path: str | None = None,
|
||||
resize_shape: Union[Tuple[int, int], Dict[str, tuple],
|
||||
None] = None,
|
||||
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||
random_crop: bool = True,
|
||||
# replace BatchNorm with GroupNorm
|
||||
use_group_norm: bool = False,
|
||||
# use single rgb model for all rgb inputs
|
||||
share_rgb_model: bool = False,
|
||||
# renormalize rgb input with imagenet normalization
|
||||
# assuming input in [0,1]
|
||||
imagenet_norm: bool = False,
|
||||
use_spatial_softmax=False,
|
||||
spatial_softmax_kp=32,
|
||||
use_dinoSiglip=False):
|
||||
"""
|
||||
Assumes rgb input: B,C,H,W
|
||||
Assumes low_dim input: B,D
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if not shape_meta_path:
|
||||
shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")
|
||||
|
||||
with open(shape_meta_path, 'r') as file:
|
||||
shape_meta = json.load(file)
|
||||
|
||||
rgb_model = instantiate_from_config(rgb_model_config)
|
||||
|
||||
rgb_keys = list()
|
||||
low_dim_keys = list()
|
||||
key_model_map = nn.ModuleDict()
|
||||
key_transform_map = nn.ModuleDict()
|
||||
key_shape_map = dict()
|
||||
|
||||
# handle sharing vision backbone
|
||||
if share_rgb_model:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
key_model_map['rgb'] = rgb_model
|
||||
|
||||
obs_shape_meta = shape_meta['obs']
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr['shape'])
|
||||
type = attr.get('type', 'low_dim')
|
||||
key_shape_map[key] = shape
|
||||
if type == 'rgb':
|
||||
rgb_keys.append(key)
|
||||
if not use_dinoSiglip:
|
||||
# configure model for this key
|
||||
this_model = None
|
||||
if not share_rgb_model:
|
||||
if isinstance(rgb_model, dict):
|
||||
# have provided model for each key
|
||||
this_model = rgb_model[key]
|
||||
else:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
# have a copy of the rgb model
|
||||
this_model = copy.deepcopy(rgb_model)
|
||||
|
||||
if this_model is not None:
|
||||
if use_group_norm:
|
||||
this_model = replace_submodules(
|
||||
root_module=this_model,
|
||||
predicate=lambda x: isinstance(
|
||||
x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16,
|
||||
num_channels=x.num_features))
|
||||
key_model_map[key] = this_model
|
||||
|
||||
# configure resize
|
||||
input_shape = shape
|
||||
this_resizer = nn.Identity()
|
||||
if resize_shape is not None:
|
||||
if isinstance(resize_shape, dict):
|
||||
h, w = resize_shape[key]
|
||||
else:
|
||||
h, w = resize_shape
|
||||
this_resizer = torchvision.transforms.Resize(size=(h,
|
||||
w))
|
||||
input_shape = (shape[0], h, w)
|
||||
|
||||
# configure randomizer
|
||||
this_randomizer = nn.Identity()
|
||||
if crop_shape is not None:
|
||||
if isinstance(crop_shape, dict):
|
||||
h, w = crop_shape[key]
|
||||
else:
|
||||
h, w = crop_shape
|
||||
if random_crop:
|
||||
this_randomizer = CropRandomizer(
|
||||
input_shape=input_shape,
|
||||
crop_height=h,
|
||||
crop_width=w,
|
||||
num_crops=1,
|
||||
pos_enc=False)
|
||||
else:
|
||||
this_normalizer = torchvision.transforms.CenterCrop(
|
||||
size=(h, w))
|
||||
# configure normalizer
|
||||
this_normalizer = nn.Identity()
|
||||
if imagenet_norm:
|
||||
this_normalizer = torchvision.transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
this_transform = nn.Sequential(this_resizer,
|
||||
this_randomizer,
|
||||
this_normalizer)
|
||||
key_transform_map[key] = this_transform
|
||||
else:
|
||||
key_model_map[key] = rgb_model
|
||||
elif type == 'low_dim':
|
||||
low_dim_keys.append(key)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported obs type: {type}")
|
||||
|
||||
rgb_keys = sorted(rgb_keys)
|
||||
low_dim_keys = sorted(low_dim_keys)
|
||||
|
||||
self.shape_meta = shape_meta
|
||||
self.key_model_map = key_model_map
|
||||
self.key_transform_map = key_transform_map
|
||||
self.share_rgb_model = share_rgb_model
|
||||
self.rgb_keys = rgb_keys
|
||||
self.low_dim_keys = low_dim_keys
|
||||
self.key_shape_map = key_shape_map
|
||||
self.use_dinoSiglip = use_dinoSiglip
|
||||
|
||||
##NOTE add spatial softmax
|
||||
self.use_spatial_softmax = use_spatial_softmax
|
||||
if use_spatial_softmax and not use_dinoSiglip:
|
||||
model = nn.Sequential(
|
||||
key_model_map['image'].conv1,
|
||||
key_model_map['image'].bn1,
|
||||
key_model_map['image'].relu,
|
||||
key_model_map['image'].maxpool,
|
||||
key_model_map['image'].layer1,
|
||||
key_model_map['image'].layer2,
|
||||
key_model_map['image'].layer3,
|
||||
key_model_map['image'].layer4,
|
||||
)
|
||||
key_model_map['image'] = model
|
||||
input_shape = self.output_shape(resnet_output_shape=True)
|
||||
self.spatial_softmax = SpatialSoftmax(input_shape,
|
||||
num_kp=spatial_softmax_kp)
|
||||
|
||||
def forward(self, obs_dict, resnet_output_shape=False):
|
||||
batch_size = None
|
||||
features = list()
|
||||
# process rgb input
|
||||
if self.share_rgb_model:
|
||||
# pass all rgb obs to rgb model
|
||||
imgs = list()
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
imgs.append(img)
|
||||
# (N*B,C,H,W)
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
# (N*B,D)
|
||||
feature = self.key_model_map['rgb'](imgs)
|
||||
# (N,B,D)
|
||||
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
|
||||
# (B,N,D)
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
if not self.use_dinoSiglip:
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
else:
|
||||
feature = self.key_model_map[key](img)[:, :1, :]
|
||||
|
||||
if resnet_output_shape:
|
||||
return feature
|
||||
if not self.use_dinoSiglip and self.use_spatial_softmax:
|
||||
feature = self.spatial_softmax(feature)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
|
||||
# process lowdim input
|
||||
for key in self.low_dim_keys:
|
||||
data = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = data.shape[0]
|
||||
else:
|
||||
assert batch_size == data.shape[0]
|
||||
assert data.shape[1:] == self.key_shape_map[key]
|
||||
features.append(data)
|
||||
|
||||
# concatenate all features
|
||||
result = torch.cat(features, dim=-1)
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def output_shape(self, resnet_output_shape=False):
|
||||
example_obs_dict = dict()
|
||||
obs_shape_meta = self.shape_meta['obs']
|
||||
batch_size = 1
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr['shape'])
|
||||
this_obs = torch.zeros((batch_size, ) + shape,
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
example_obs_dict[key] = this_obs
|
||||
example_output = self.forward(example_obs_dict,
|
||||
resnet_output_shape=resnet_output_shape)
|
||||
output_shape = example_output.shape[1:]
|
||||
return output_shape
|
||||
473
src/unifolm_wma/models/samplers/ddim.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
|
||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.counter = 0
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize="uniform",
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
||||
.device)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
|
||||
self.ddim_scale_arr_prev = torch.cat(
|
||||
[self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# Calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# DDIM sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
schedule_verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
precision=None,
|
||||
fs=None,
|
||||
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
|
||||
# Check condition bs
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
try:
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
except:
|
||||
cbs = conditioning[list(
|
||||
conditioning.keys())[0]][0].shape[0]
|
||||
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S,
|
||||
ddim_discretize=timestep_spacing,
|
||||
ddim_eta=eta,
|
||||
verbose=schedule_verbose)
|
||||
|
||||
# Make shape
|
||||
if len(shape) == 3:
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
elif len(shape) == 4:
|
||||
C, T, H, W = shape
|
||||
size = (batch_size, C, T, H, W)
|
||||
|
||||
samples, actions, states, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
verbose=verbose,
|
||||
precision=precision,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
return samples, actions, states, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
verbose=True,
|
||||
precision=None,
|
||||
fs=None,
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
device = self.model.betas.device
|
||||
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
|
||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
else:
|
||||
img = x_T
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
|
||||
if precision is not None:
|
||||
if precision == 16:
|
||||
img = img.to(dtype=torch.float16)
|
||||
action = action.to(dtype=torch.float16)
|
||||
state = state.to(dtype=torch.float16)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1) *
|
||||
self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {
|
||||
'x_inter': [img],
|
||||
'pred_x0': [img],
|
||||
'x_inter_action': [action],
|
||||
'pred_x0_action': [action],
|
||||
'x_inter_state': [state],
|
||||
'pred_x0_state': [state],
|
||||
}
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
if verbose:
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
else:
|
||||
iterator = time_range
|
||||
|
||||
clean_cond = kwargs.pop("clean_cond", False)
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
|
||||
return img, action, state, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self,
|
||||
x,
|
||||
x_action,
|
||||
x_state,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
uc_type=None,
|
||||
conditional_guidance_scale_temporal=None,
|
||||
mask=None,
|
||||
x0=None,
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
b, *_, device = *x.shape, x.device
|
||||
if x.dim() == 5:
|
||||
is_video = True
|
||||
else:
|
||||
is_video = False
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, c, **kwargs) # unet denoiser
|
||||
else:
|
||||
# do_classifier_free_guidance
|
||||
if isinstance(c, torch.Tensor) or isinstance(c, dict):
|
||||
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, c, **kwargs)
|
||||
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, unconditional_conditioning,
|
||||
**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t_cond - e_t_uncond)
|
||||
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
||||
e_t_cond_action - e_t_uncond_action)
|
||||
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
||||
e_t_cond_state - e_t_uncond_state)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
model_output = rescale_noise_cfg(
|
||||
model_output, e_t_cond, guidance_rescale=guidance_rescale)
|
||||
model_output_action = rescale_noise_cfg(
|
||||
model_output_action,
|
||||
e_t_cond_action,
|
||||
guidance_rescale=guidance_rescale)
|
||||
model_output_state = rescale_noise_cfg(
|
||||
model_output_state,
|
||||
e_t_cond_state,
|
||||
guidance_rescale=guidance_rescale)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||
**corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
|
||||
return x_prev, pred_x0, model_output_action, model_output_state
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps
|
||||
) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0], ),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
return x_dec
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
|
||||
noise)
|
||||
0
src/unifolm_wma/modules/__init__.py
Normal file
806
src/unifolm_wma/modules/attention.py
Normal file
@@ -0,0 +1,806 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from functools import partial
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
|
||||
|
||||
class RelativePosition(nn.Module):
|
||||
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
|
||||
|
||||
def __init__(self, num_units, max_relative_position):
|
||||
super().__init__()
|
||||
self.num_units = num_units
|
||||
self.max_relative_position = max_relative_position
|
||||
self.embeddings_table = nn.Parameter(
|
||||
torch.Tensor(max_relative_position * 2 + 1, num_units))
|
||||
nn.init.xavier_uniform_(self.embeddings_table)
|
||||
|
||||
def forward(self, length_q, length_k):
|
||||
device = self.embeddings_table.device
|
||||
range_vec_q = torch.arange(length_q, device=device)
|
||||
range_vec_k = torch.arange(length_k, device=device)
|
||||
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
|
||||
distance_mat_clipped = torch.clamp(distance_mat,
|
||||
-self.max_relative_position,
|
||||
self.max_relative_position)
|
||||
final_mat = distance_mat_clipped + self.max_relative_position
|
||||
final_mat = final_mat.long()
|
||||
embeddings = self.embeddings_table[final_mat]
|
||||
return embeddings
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.,
|
||||
relative_position=False,
|
||||
temporal_length=None,
|
||||
video_length=None,
|
||||
agent_state_context_len=2,
|
||||
agent_action_context_len=16,
|
||||
image_cross_attention=False,
|
||||
image_cross_attention_scale=1.0,
|
||||
agent_state_cross_attention_scale=1.0,
|
||||
agent_action_cross_attention_scale=1.0,
|
||||
cross_attention_scale_learnable=False,
|
||||
text_context_len=77):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout))
|
||||
|
||||
self.relative_position = relative_position
|
||||
if self.relative_position:
|
||||
assert (temporal_length is not None)
|
||||
self.relative_position_k = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
self.relative_position_v = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
else:
|
||||
## only used for spatial attention, while NOT for temporal attention
|
||||
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
||||
self.forward = self.efficient_forward
|
||||
|
||||
self.video_length = video_length
|
||||
self.image_cross_attention = image_cross_attention
|
||||
self.image_cross_attention_scale = image_cross_attention_scale
|
||||
self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
|
||||
self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
|
||||
self.text_context_len = text_context_len
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
if cross_attention_scale_learnable:
|
||||
self.register_parameter('alpha_ctx',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_cas',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_caa',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
k_as, v_as, out_as = None, None, None
|
||||
k_aa, v_aa, out_aa = None, None, None
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:,
|
||||
self.agent_state_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
if self.relative_position:
|
||||
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
|
||||
k2 = self.relative_position_k(len_q, len_k)
|
||||
sim2 = einsum('b t d, t s d -> b t s', q,
|
||||
k2) * self.scale # TODO check
|
||||
sim += sim2
|
||||
del k
|
||||
|
||||
if exists(mask):
|
||||
## feasible for causal attention mask only
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
||||
if self.relative_position:
|
||||
v2 = self.relative_position_v(len_q, len_v)
|
||||
out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
|
||||
out += out2
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||
## for image cross-attention
|
||||
k_ip, v_ip = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_ip, v_ip))
|
||||
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_ip) * self.scale
|
||||
del k_ip
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## for agent state cross-attention
|
||||
k_as, v_as = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_as, v_as))
|
||||
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_as) * self.scale
|
||||
del k_as
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## for agent action cross-attention
|
||||
k_aa, v_aa = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_aa, v_aa))
|
||||
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_aa) * self.scale
|
||||
del k_aa
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if out_ip is not None and out_as is not None and out_aa is not None:
|
||||
if self.cross_attention_scale_learnable:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
||||
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
||||
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
||||
else:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip + \
|
||||
self.agent_state_cross_attention_scale * out_as + \
|
||||
self.agent_action_cross_attention_scale * out_aa
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def efficient_forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k, v, out = None, None, None
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
k_as, v_as, out_as = None, None, None
|
||||
k_aa, v_aa, out_aa = None, None, None
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
if context.shape[1] == self.text_context_len + self.video_length:
|
||||
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
else:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
|
||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16).to(k_aa.device)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
|
||||
if k is not None:
|
||||
k, v = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
||||
(k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q,
|
||||
k,
|
||||
v,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out = (out.unsqueeze(0).reshape(
|
||||
b, self.heads, out.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
|
||||
if k_ip is not None:
|
||||
# For image cross-attention
|
||||
k_ip, v_ip = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_ip, v_ip),
|
||||
)
|
||||
out_ip = xformers.ops.memory_efficient_attention(q,
|
||||
k_ip,
|
||||
v_ip,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out_ip = (out_ip.unsqueeze(0).reshape(
|
||||
b, self.heads, out_ip.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_ip.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
|
||||
if k_as is not None:
|
||||
# For agent state cross-attention
|
||||
k_as, v_as = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_as, v_as),
|
||||
)
|
||||
out_as = xformers.ops.memory_efficient_attention(q,
|
||||
k_as,
|
||||
v_as,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out_as = (out_as.unsqueeze(0).reshape(
|
||||
b, self.heads, out_as.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_as.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
if k_aa is not None:
|
||||
# For agent action cross-attention
|
||||
k_aa, v_aa = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_aa, v_aa),
|
||||
)
|
||||
|
||||
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
|
||||
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
|
||||
attn_mask_aa = attn_mask_aa.to(q.dtype)
|
||||
|
||||
out_aa = xformers.ops.memory_efficient_attention(
|
||||
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
||||
|
||||
out_aa = (out_aa.unsqueeze(0).reshape(
|
||||
b, self.heads, out_aa.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_aa.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
|
||||
out = 0.0 if out is None else out
|
||||
out_ip = 0.0 if out_ip is None else out_ip
|
||||
out_as = 0.0 if out_as is None else out_as
|
||||
out_aa = 0.0 if out_aa is None else out_aa
|
||||
|
||||
if self.cross_attention_scale_learnable:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
||||
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
||||
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
||||
|
||||
else:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip + \
|
||||
self.agent_state_cross_attention_scale * out_as + \
|
||||
self.agent_action_cross_attention_scale * out_aa
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
||||
num_token = l2 // block_size
|
||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
||||
attn_mask[mask] = float('-inf')
|
||||
return attn_mask
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
attention_cls=None,
|
||||
video_length=None,
|
||||
agent_state_context_len=2,
|
||||
agent_action_context_len=16,
|
||||
image_cross_attention=False,
|
||||
image_cross_attention_scale=1.0,
|
||||
cross_attention_scale_learnable=False,
|
||||
text_context_len=77):
|
||||
super().__init__()
|
||||
attn_cls = CrossAttention if attention_cls is None else attention_cls
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None)
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
video_length=video_length,
|
||||
agent_state_context_len=agent_state_context_len,
|
||||
agent_action_context_len=agent_action_context_len,
|
||||
image_cross_attention=image_cross_attention,
|
||||
image_cross_attention_scale=image_cross_attention_scale,
|
||||
cross_attention_scale_learnable=cross_attention_scale_learnable,
|
||||
text_context_len=text_context_len)
|
||||
self.image_cross_attention = image_cross_attention
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None, mask=None, **kwargs):
|
||||
# implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
|
||||
input_tuple = (
|
||||
x,
|
||||
) # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
|
||||
if context is not None:
|
||||
input_tuple = (x, context)
|
||||
if mask is not None:
|
||||
forward_mask = partial(self._forward, mask=mask)
|
||||
return checkpoint(forward_mask, (x, ), self.parameters(),
|
||||
self.checkpoint)
|
||||
return checkpoint(self._forward, input_tuple, self.parameters(),
|
||||
self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, mask=None):
|
||||
x = self.attn1(self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
mask=mask) + x
|
||||
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data in spatial axis.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
use_checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
use_linear=False,
|
||||
video_length=None,
|
||||
agent_state_context_len=2,
|
||||
agent_action_context_len=16,
|
||||
image_cross_attention=False,
|
||||
cross_attention_scale_learnable=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
attention_cls = None
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
checkpoint=use_checkpoint,
|
||||
attention_cls=attention_cls,
|
||||
video_length=video_length,
|
||||
agent_state_context_len=agent_state_context_len,
|
||||
agent_action_context_len=agent_action_context_len,
|
||||
image_cross_attention=image_cross_attention,
|
||||
cross_attention_scale_learnable=cross_attention_scale_learnable,
|
||||
) for d in range(depth)
|
||||
])
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
else:
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None, **kwargs):
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, context=context, **kwargs)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class TemporalTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data in temporal axis.
|
||||
First, reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
use_checkpoint=True,
|
||||
use_linear=False,
|
||||
only_self_att=True,
|
||||
causal_attention=False,
|
||||
causal_block_size=1,
|
||||
relative_position=False,
|
||||
temporal_length=None):
|
||||
super().__init__()
|
||||
self.only_self_att = only_self_att
|
||||
self.relative_position = relative_position
|
||||
self.causal_attention = causal_attention
|
||||
self.causal_block_size = causal_block_size
|
||||
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.proj_in = nn.Conv1d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv1d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
if relative_position:
|
||||
assert (temporal_length is not None)
|
||||
attention_cls = partial(CrossAttention,
|
||||
relative_position=True,
|
||||
temporal_length=temporal_length)
|
||||
else:
|
||||
attention_cls = partial(CrossAttention,
|
||||
temporal_length=temporal_length)
|
||||
if self.causal_attention:
|
||||
assert (temporal_length is not None)
|
||||
self.mask = torch.tril(
|
||||
torch.ones([1, temporal_length, temporal_length]))
|
||||
|
||||
if self.only_self_att:
|
||||
context_dim = None
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
attention_cls=attention_cls,
|
||||
checkpoint=use_checkpoint)
|
||||
for d in range(depth)
|
||||
])
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv1d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
else:
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None):
|
||||
b, c, t, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
temp_mask = None
|
||||
if self.causal_attention:
|
||||
# Slice the from mask map
|
||||
temp_mask = self.mask[:, :t, :t].to(x.device)
|
||||
|
||||
if temp_mask is not None:
|
||||
mask = temp_mask.to(x.device)
|
||||
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
if self.only_self_att:
|
||||
# NOTE: if no context is given, cross-attention defaults to self-attention
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, mask=mask)
|
||||
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
||||
else:
|
||||
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
||||
context = rearrange(context, '(b t) l con -> b t l con',
|
||||
t=t).contiguous()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
# Calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
||||
for j in range(b):
|
||||
context_j = repeat(context[j],
|
||||
't l con -> (t r) l con',
|
||||
r=(h * w) // t,
|
||||
t=t).contiguous()
|
||||
# Note: causal mask will not applied in cross-attention case
|
||||
x[j] = block(x[j], context=context_j)
|
||||
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
|
||||
w=w).contiguous()
|
||||
|
||||
return x + x_in
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(
|
||||
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv,
|
||||
'b (qkv heads c) h w -> qkv b heads c (h w)',
|
||||
heads=self.heads,
|
||||
qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out,
|
||||
'b heads c (h w) -> b (heads c) h w',
|
||||
heads=self.heads,
|
||||
h=h,
|
||||
w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# Compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# Attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
630
src/unifolm_wma/modules/encoders/condition.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import kornia
|
||||
import open_clip
|
||||
import math
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from unifolm_wma.utils.common import autocast
|
||||
from unifolm_wma.utils.utils import count_params
|
||||
from unifolm_wma.modules.encoders.resampler import reshape_tensor
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
self.n_classes = n_classes
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def forward(self, batch, key=None, disable_dropout=False):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes -
|
||||
1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs, ), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
|
||||
def __init__(self,
|
||||
version="google/t5-v1_1-xxl",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="last",
|
||||
layer_idx=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = layer_idx
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert 0 <= abs(layer_idx) <= 12
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens,
|
||||
output_hidden_states=self.layer == "hidden")
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class ClipImageEmbedder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
jit=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
antialias=True,
|
||||
ucg_rate=0.):
|
||||
super().__init__()
|
||||
from clip import load as load_clip
|
||||
self.model, _ = load_clip(name=model, device=device, jit=jit)
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False)
|
||||
self.register_buffer('std',
|
||||
torch.Tensor([0.26862954, 0.26130258,
|
||||
0.27577711]),
|
||||
persistent=False)
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# re-normalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x, no_dropout=False):
|
||||
# x is assumed to be in range [-1,1]
|
||||
out = self.model.encode_image(self.preprocess(x))
|
||||
out = out.to(x.dtype)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
out = torch.bernoulli(
|
||||
(1. - self.ucg_rate) *
|
||||
torch.ones(out.shape[0], device=out.device))[:, None] * out
|
||||
return out
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
# "pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
arch="ViT-H-14",
|
||||
version="laion2b_s32b_b79k",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="last"):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "last":
|
||||
self.layer_idx = 0
|
||||
elif self.layer == "penultimate":
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(
|
||||
text) ## all clip models use 77 as context length
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP vision transformer encoder for images
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch="ViT-H-14",
|
||||
version="laion2b_s32b_b79k",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="pooled",
|
||||
antialias=True,
|
||||
ucg_rate=0.):
|
||||
super().__init__()
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch,
|
||||
device=torch.device('cpu'),
|
||||
pretrained=version,
|
||||
)
|
||||
del model.transformer
|
||||
self.model = model
|
||||
# self.mapper = torch.nn.Linear(1280, 1024)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "penultimate":
|
||||
raise NotImplementedError()
|
||||
self.layer_idx = 1
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False)
|
||||
self.register_buffer('std',
|
||||
torch.Tensor([0.26862954, 0.26130258,
|
||||
0.27577711]),
|
||||
persistent=False)
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@autocast
|
||||
def forward(self, image, no_dropout=False):
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
z = torch.bernoulli(
|
||||
(1. - self.ucg_rate) *
|
||||
torch.ones(z.shape[0], device=z.device))[:, None] * z
|
||||
return z
|
||||
|
||||
def encode_with_vision_transformer(self, img):
|
||||
img = self.preprocess(img)
|
||||
x = self.model.visual(img)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP vision transformer encoder for images
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch="ViT-H-14",
|
||||
version="laion2b_s32b_b79k",
|
||||
device="cuda",
|
||||
freeze=True,
|
||||
layer="pooled",
|
||||
antialias=True):
|
||||
super().__init__()
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch,
|
||||
device=torch.device('cpu'),
|
||||
pretrained=version,
|
||||
)
|
||||
del model.transformer
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "penultimate":
|
||||
raise NotImplementedError()
|
||||
self.layer_idx = 1
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False)
|
||||
self.register_buffer('std',
|
||||
torch.Tensor([0.26862954, 0.26130258,
|
||||
0.27577711]),
|
||||
persistent=False)
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image, no_dropout=False):
|
||||
## image: b c h w
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
return z
|
||||
|
||||
def encode_with_vision_transformer(self, x):
|
||||
x = self.preprocess(x)
|
||||
|
||||
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
||||
if self.model.visual.input_patchnorm:
|
||||
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
self.model.visual.grid_size[0],
|
||||
self.model.visual.patch_size[0],
|
||||
self.model.visual.grid_size[1],
|
||||
self.model.visual.patch_size[1])
|
||||
x = x.permute(0, 2, 4, 1, 3, 5)
|
||||
x = x.reshape(
|
||||
x.shape[0], self.model.visual.grid_size[0] *
|
||||
self.model.visual.grid_size[1], -1)
|
||||
x = self.model.visual.patchnorm_pre_ln(x)
|
||||
x = self.model.visual.conv1(x)
|
||||
else:
|
||||
x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
|
||||
# class embeddings and positional embeddings
|
||||
x = torch.cat([
|
||||
self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
|
||||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
|
||||
],
|
||||
dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.model.visual.positional_embedding.to(x.dtype)
|
||||
|
||||
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
||||
x = self.model.visual.patch_dropout(x)
|
||||
x = self.model.visual.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.model.visual.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
|
||||
def __init__(self,
|
||||
clip_version="openai/clip-vit-large-patch14",
|
||||
t5_version="google/t5-v1_1-xl",
|
||||
device="cuda",
|
||||
clip_max_length=77,
|
||||
t5_max_length=77):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version,
|
||||
device,
|
||||
max_length=clip_max_length)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version,
|
||||
device,
|
||||
max_length=t5_max_length)
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
|
||||
)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
def forward(self, text):
|
||||
clip_z = self.clip_encoder.encode(text)
|
||||
t5_z = self.t5_encoder.encode(text)
|
||||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
class LinearProjector(nn.Module):
|
||||
|
||||
def __init__(self, input_dim: int, output_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.projector = nn.Linear(input_dim, output_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.projector(x)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
mlp_type: str = "gelu-mlp") -> None:
|
||||
super().__init__()
|
||||
if mlp_type == "gelu-mlp":
|
||||
self.projector = nn.Sequential(
|
||||
nn.Linear(input_dim, output_dim, bias=True),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(output_dim, output_dim, bias=True),
|
||||
)
|
||||
elif mlp_type == "silu-mlp":
|
||||
self.projector = nn.Sequential(
|
||||
nn.Linear(input_dim, output_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(output_dim, output_dim, bias=True),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Projector with `{mlp_type = }` is not supported!")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.projector(x)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(
|
||||
-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
|
||||
inner_dim = int(dim * mult)
|
||||
if ffd_type == "gelu-ffd":
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
elif ffd_type == "silu-ffd":
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
|
||||
|
||||
|
||||
class SATokenProjector(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=1024,
|
||||
depth=1,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=16,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
chunk_size=None):
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
if chunk_size is not None:
|
||||
num_queries = num_queries * chunk_size
|
||||
|
||||
self.latents = nn.Parameter(
|
||||
torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head,
|
||||
heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
latents = self.proj_out(latents)
|
||||
latents = self.norm_out(latents)
|
||||
return latents
|
||||
153
src/unifolm_wma/modules/encoders/resampler.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
||||
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ImageProjModel(nn.Module):
|
||||
"""Projection Model"""
|
||||
|
||||
def __init__(self,
|
||||
cross_attention_dim=1024,
|
||||
clip_embeddings_dim=1024,
|
||||
clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.proj = nn.Linear(
|
||||
clip_embeddings_dim,
|
||||
self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds):
|
||||
#embeds = image_embeds
|
||||
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
||||
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(
|
||||
-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=1024,
|
||||
depth=8,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=8,
|
||||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
video_length=None, # using frame-wise version or not
|
||||
):
|
||||
super().__init__()
|
||||
## queries for a single frame / image
|
||||
self.num_queries = num_queries
|
||||
self.video_length = video_length
|
||||
|
||||
## <num_queries> queries for each frame
|
||||
if video_length is not None:
|
||||
num_queries = num_queries * video_length
|
||||
|
||||
self.latents = nn.Parameter(
|
||||
torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head,
|
||||
heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
latents = self.norm_out(latents)
|
||||
|
||||
return latents
|
||||
1005
src/unifolm_wma/modules/networks/ae_modules.py
Normal file
848
src/unifolm_wma/modules/networks/wma_model.py
Normal file
@@ -0,0 +1,848 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
from abc import abstractmethod
|
||||
from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
|
||||
from collections.abc import Mapping, Iterable, Callable
|
||||
|
||||
from unifolm_wma.utils.diffusion import timestep_embedding
|
||||
from unifolm_wma.utils.common import checkpoint
|
||||
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
|
||||
avg_pool_nd, normalization)
|
||||
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None, batch_size=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb, batch_size=batch_size)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
elif isinstance(layer, TemporalTransformer):
|
||||
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
|
||||
x = layer(x, context)
|
||||
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||
mode='nearest')
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
:param use_temporal_conv: if True, use the temporal convolution.
|
||||
:param use_image_dataset: if True, the temporal parameters will not be optimized.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
use_temporal_conv=False,
|
||||
tempspatial_aware=False):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.use_temporal_conv = use_temporal_conv
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels
|
||||
if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(dims,
|
||||
channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
padding=1)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels,
|
||||
1)
|
||||
|
||||
if self.use_temporal_conv:
|
||||
self.temopral_conv = TemporalConvBlock(
|
||||
self.out_channels,
|
||||
self.out_channels,
|
||||
dropout=0.1,
|
||||
spatial_aware=tempspatial_aware)
|
||||
|
||||
def forward(self, x, emb, batch_size=None):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
input_tuple = (x, emb)
|
||||
if batch_size:
|
||||
forward_batchsize = partial(self._forward, batch_size=batch_size)
|
||||
return checkpoint(forward_batchsize, input_tuple,
|
||||
self.parameters(), self.use_checkpoint)
|
||||
return checkpoint(self._forward, input_tuple, self.parameters(),
|
||||
self.use_checkpoint)
|
||||
|
||||
def _forward(self, x, emb, batch_size=None):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
h = self.skip_connection(x) + h
|
||||
|
||||
if self.use_temporal_conv and batch_size:
|
||||
h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
|
||||
h = self.temopral_conv(h)
|
||||
h = rearrange(h, 'b c t h w -> (b t) c h w')
|
||||
return h
|
||||
|
||||
|
||||
class TemporalConvBlock(nn.Module):
|
||||
"""
|
||||
Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
dropout=0.0,
|
||||
spatial_aware=False):
|
||||
super(TemporalConvBlock, self).__init__()
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
|
||||
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
|
||||
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
|
||||
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
|
||||
|
||||
# conv layers
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_channels), nn.SiLU(),
|
||||
nn.Conv3d(in_channels,
|
||||
out_channels,
|
||||
th_kernel_shape,
|
||||
padding=th_padding_shape))
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv3d(out_channels,
|
||||
in_channels,
|
||||
tw_kernel_shape,
|
||||
padding=tw_padding_shape))
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv3d(out_channels,
|
||||
in_channels,
|
||||
th_kernel_shape,
|
||||
padding=th_padding_shape))
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv3d(out_channels,
|
||||
in_channels,
|
||||
tw_kernel_shape,
|
||||
padding=tw_padding_shape))
|
||||
|
||||
# Zero out the last layer params,so the conv block is identity
|
||||
nn.init.zeros_(self.conv4[-1].weight)
|
||||
nn.init.zeros_(self.conv4[-1].bias)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
|
||||
return identity + x
|
||||
|
||||
|
||||
class WMAModel(nn.Module):
|
||||
"""
|
||||
The full World-Model-Action model.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
attention_resolutions: Sequence[int],
|
||||
dropout: float = 0.0,
|
||||
channel_mult: Sequence[int] = (1, 2, 4, 8),
|
||||
conv_resample: bool = True,
|
||||
dims: int = 2,
|
||||
context_dim: int | None = None,
|
||||
use_scale_shift_norm: bool = False,
|
||||
resblock_updown: bool = False,
|
||||
num_heads: int = -1,
|
||||
num_head_channels: int = -1,
|
||||
transformer_depth: int = 1,
|
||||
use_linear: bool = False,
|
||||
use_checkpoint: bool = False,
|
||||
temporal_conv: bool = False,
|
||||
tempspatial_aware: bool = False,
|
||||
temporal_attention: bool = True,
|
||||
use_relative_position: bool = True,
|
||||
use_causal_attention: bool = False,
|
||||
temporal_length: int | None = None,
|
||||
use_fp16: bool = False,
|
||||
addition_attention: bool = False,
|
||||
temporal_selfatt_only: bool = True,
|
||||
image_cross_attention: bool = False,
|
||||
cross_attention_scale_learnable: bool = False,
|
||||
default_fs: int = 4,
|
||||
fs_condition: bool = False,
|
||||
n_obs_steps: int = 1,
|
||||
num_stem_token: int = 1,
|
||||
unet_head_config: OmegaConf | None = None,
|
||||
stem_process_config: OmegaConf | None = None,
|
||||
base_model_gen_only: bool = False):
|
||||
"""
|
||||
Initialize the World-Model-Action network.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels to the backbone.
|
||||
model_channels: Base channel width for the UNet/backbone.
|
||||
out_channels: Number of output channels.
|
||||
num_res_blocks: Number of residual blocks per resolution stage.
|
||||
attention_resolutions: Resolutions at which to enable attention.
|
||||
dropout: Dropout probability used inside residual/attention blocks.
|
||||
channel_mult: Multipliers for channels at each resolution level.
|
||||
conv_resample: If True, use convolutional resampling for up/down sampling.
|
||||
dims: Spatial dimensionality of the backbone (1/2/3).
|
||||
context_dim: Optional context embedding dimension (for cross-attention).
|
||||
use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
|
||||
resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
|
||||
num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
|
||||
num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
|
||||
transformer_depth: Number of transformer/attention blocks per stage.
|
||||
use_linear: Use linear attention variants where applicable.
|
||||
use_checkpoint: Enable gradient checkpointing in blocks to save memory.
|
||||
temporal_conv: Include temporal convolution along the time dimension.
|
||||
tempspatial_aware: If True, use time–space aware blocks.
|
||||
temporal_attention: Enable temporal self-attention.
|
||||
use_relative_position: Use relative position encodings in attention.
|
||||
use_causal_attention: Use causal (uni-directional) attention along time.
|
||||
temporal_length: Optional maximum temporal length expected by the model.
|
||||
use_fp16: Enable half-precision layers/normalization where supported.
|
||||
addition_attention: Add auxiliary attention modules.
|
||||
temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
|
||||
image_cross_attention: Enable cross-attention with image embeddings.
|
||||
cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
|
||||
default_fs: Default frame-stride / fps.
|
||||
fs_condition: If True, condition on frame-stride/fps features.
|
||||
n_obs_steps: Number of observed steps used in conditioning heads.
|
||||
num_stem_token: Number of stem tokens for action tokenization.
|
||||
unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
|
||||
stem_process_config: OmegaConf for stem/preprocessor module.
|
||||
base_model_gen_only: Perform the generation using the base model with out action and state outputs.
|
||||
"""
|
||||
|
||||
super(WMAModel, self).__init__()
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.temporal_attention = temporal_attention
|
||||
time_embed_dim = model_channels * 4
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
temporal_self_att_only = True
|
||||
self.addition_attention = addition_attention
|
||||
self.temporal_length = temporal_length
|
||||
self.image_cross_attention = image_cross_attention
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
self.default_fs = default_fs
|
||||
self.fs_condition = fs_condition
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.num_stem_token = num_stem_token
|
||||
self.base_model_gen_only = base_model_gen_only
|
||||
|
||||
# Time embedding blocks
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
if fs_condition:
|
||||
self.fps_embedding = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
nn.init.zeros_(self.fps_embedding[-1].weight)
|
||||
nn.init.zeros_(self.fps_embedding[-1].bias)
|
||||
# Input Block
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
||||
])
|
||||
if self.addition_attention:
|
||||
self.init_attn = TimestepEmbedSequential(
|
||||
TemporalTransformer(model_channels,
|
||||
n_heads=8,
|
||||
d_head=num_head_channels,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
only_self_att=temporal_selfatt_only,
|
||||
causal_attention=False,
|
||||
relative_position=use_relative_position,
|
||||
temporal_length=temporal_length))
|
||||
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
tempspatial_aware=tempspatial_aware,
|
||||
use_temporal_conv=temporal_conv)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disable_self_attn=False,
|
||||
video_length=temporal_length,
|
||||
agent_state_context_len=self.n_obs_steps,
|
||||
agent_action_context_len=self.temporal_length *
|
||||
num_stem_token,
|
||||
image_cross_attention=self.image_cross_attention,
|
||||
cross_attention_scale_learnable=self.
|
||||
cross_attention_scale_learnable,
|
||||
))
|
||||
if self.temporal_attention:
|
||||
layers.append(
|
||||
TemporalTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
only_self_att=temporal_self_att_only,
|
||||
causal_attention=use_causal_attention,
|
||||
relative_position=use_relative_position,
|
||||
temporal_length=temporal_length))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True)
|
||||
if resblock_updown else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
layers = [
|
||||
ResBlock(ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
tempspatial_aware=tempspatial_aware,
|
||||
use_temporal_conv=temporal_conv),
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disable_self_attn=False,
|
||||
video_length=temporal_length,
|
||||
agent_state_context_len=self.n_obs_steps,
|
||||
agent_action_context_len=self.temporal_length * num_stem_token,
|
||||
image_cross_attention=self.image_cross_attention,
|
||||
cross_attention_scale_learnable=self.
|
||||
cross_attention_scale_learnable)
|
||||
]
|
||||
if self.temporal_attention:
|
||||
layers.append(
|
||||
TemporalTransformer(ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
only_self_att=temporal_self_att_only,
|
||||
causal_attention=use_causal_attention,
|
||||
relative_position=use_relative_position,
|
||||
temporal_length=temporal_length))
|
||||
layers.append(
|
||||
ResBlock(ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
tempspatial_aware=tempspatial_aware,
|
||||
use_temporal_conv=temporal_conv))
|
||||
|
||||
# Middle Block
|
||||
self.middle_block = TimestepEmbedSequential(*layers)
|
||||
|
||||
# Output Block
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
tempspatial_aware=tempspatial_aware,
|
||||
use_temporal_conv=temporal_conv)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disable_self_attn=False,
|
||||
video_length=temporal_length,
|
||||
agent_state_context_len=self.n_obs_steps,
|
||||
image_cross_attention=self.image_cross_attention,
|
||||
cross_attention_scale_learnable=self.
|
||||
cross_attention_scale_learnable))
|
||||
if self.temporal_attention:
|
||||
layers.append(
|
||||
TemporalTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
use_checkpoint=use_checkpoint,
|
||||
only_self_att=temporal_self_att_only,
|
||||
causal_attention=use_causal_attention,
|
||||
relative_position=use_relative_position,
|
||||
temporal_length=temporal_length))
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True)
|
||||
if resblock_updown else Upsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(
|
||||
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
# Action and state prediction unet
|
||||
unet_head_config['params']['context_dims'] = [
|
||||
mult * model_channels for mult in channel_mult
|
||||
]
|
||||
self.action_unet = instantiate_from_config(unet_head_config)
|
||||
self.state_unet = instantiate_from_config(unet_head_config)
|
||||
|
||||
# Initialize action token_projector
|
||||
self.action_token_projector = instantiate_from_config(
|
||||
stem_process_config)
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
x_action: Tensor,
|
||||
x_state: Tensor,
|
||||
timesteps: Tensor,
|
||||
context: Tensor | None = None,
|
||||
context_action: Tensor | None = None,
|
||||
features_adapter: Any = None,
|
||||
fs: Tensor | None = None,
|
||||
**kwargs) -> Tensor | tuple[Tensor, ...]:
|
||||
|
||||
"""
|
||||
Forward pass of the World-Model-Action backbone.
|
||||
|
||||
Args:
|
||||
x: Input tensor (latent video), shape (B, C,...).
|
||||
x_action: action stream input.
|
||||
x_state: state stream input.
|
||||
timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
|
||||
context: conditioning context for cross-attention.
|
||||
context_action: conditioning context specific to action/state (implementation-specific).
|
||||
features_adapter: module or dict to adapt intermediate features.
|
||||
fs: frame-stride / fps conditioning.
|
||||
|
||||
Returns:
|
||||
Tuple of Tensors for predictions:
|
||||
|
||||
"""
|
||||
|
||||
b, _, t, _, _ = x.shape
|
||||
t_emb = timestep_embedding(timesteps,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
|
||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
|
||||
# Combine emb
|
||||
if self.fs_condition:
|
||||
if fs is None:
|
||||
fs = torch.tensor([self.default_fs] * b,
|
||||
dtype=torch.long,
|
||||
device=x.device)
|
||||
fs_emb = timestep_embedding(fs,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
|
||||
fs_embed = self.fps_embedding(fs_emb)
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
emb = emb + fs_embed
|
||||
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
adapter_idx += 1
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
hs_a.append(
|
||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
hs_out = []
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs_out.append(h)
|
||||
h = h.type(x.dtype)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
if not self.base_model_gen_only:
|
||||
ba, _, _ = x_action.shape
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
# Predict state
|
||||
if b > 1:
|
||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
return y, a_y, s_y
|
||||
0
src/unifolm_wma/modules/vision/__init__.py
Normal file
244
src/unifolm_wma/modules/vision/base_vision.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
base_vision.py
|
||||
|
||||
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
|
||||
functions, and initialization logic.
|
||||
|
||||
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
|
||||
Transformer model for feature extraction.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as TVF
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize
|
||||
|
||||
|
||||
# === Utility Functions for Monkey-Patching ===
|
||||
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
||||
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
result = fn(*args, **kwargs)
|
||||
return result[0] if isinstance(result, tuple) else result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# === Interface for an Image Transform ===
|
||||
class ImageTransform(Protocol):
|
||||
|
||||
def __call__(
|
||||
self, img: Image,
|
||||
**kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
...
|
||||
|
||||
|
||||
# === Custom Torchvision Image Transforms ===
|
||||
@dataclass
|
||||
class LetterboxPad:
|
||||
padding_fill_value: Tuple[int, int, int]
|
||||
|
||||
def __call__(self, image: Image) -> Image:
|
||||
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
|
||||
(w, h), max_wh = image.size, max(image.size)
|
||||
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int(
|
||||
(max_wh - h) / 2)
|
||||
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
||||
return TVF.pad(image,
|
||||
padding,
|
||||
fill=self.padding_fill_value,
|
||||
padding_mode="constant")
|
||||
|
||||
|
||||
# === Abstract Base Class for arbitrary Vision Backbones ===
|
||||
class VisionBackbone(nn.Module, ABC):
|
||||
|
||||
def __init__(self,
|
||||
vision_backbone_id: str,
|
||||
image_resize_strategy: str,
|
||||
default_image_size: int = 224) -> None:
|
||||
super().__init__()
|
||||
self.identifier: str = vision_backbone_id
|
||||
self.image_resize_strategy: str = image_resize_strategy
|
||||
self.default_image_size: int = default_image_size
|
||||
|
||||
# Instance attributes for a Vision Backbone
|
||||
self.featurizer: nn.Module = None
|
||||
self.image_transform: ImageTransform = None
|
||||
|
||||
def get_image_transform(self) -> ImageTransform:
|
||||
return self.image_transform
|
||||
|
||||
@abstractmethod
|
||||
def get_fsdp_wrapping_policy(self) -> Callable:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default_image_resolution(self) -> Tuple[int, int, int]:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def embed_dim(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_patches(self) -> int:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def half_precision_dtype(self) -> torch.dtype:
|
||||
...
|
||||
|
||||
|
||||
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
|
||||
class TimmViTBackbone(VisionBackbone, ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone_id: str,
|
||||
timm_path_or_url: str,
|
||||
image_resize_strategy: str,
|
||||
default_image_size: int = 224,
|
||||
override_act_layer: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(vision_backbone_id,
|
||||
image_resize_strategy,
|
||||
default_image_size=default_image_size)
|
||||
self.timm_path_or_url = timm_path_or_url
|
||||
self.override_act_layer = override_act_layer
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
|
||||
if self.override_act_layer is None:
|
||||
self.featurizer: VisionTransformer = timm.create_model(
|
||||
self.timm_path_or_url,
|
||||
pretrained=True,
|
||||
num_classes=0,
|
||||
img_size=self.default_image_size)
|
||||
else:
|
||||
self.featurizer: VisionTransformer = timm.create_model(
|
||||
self.timm_path_or_url,
|
||||
pretrained=True,
|
||||
num_classes=0,
|
||||
img_size=self.default_image_size,
|
||||
act_layer=self.override_act_layer,
|
||||
)
|
||||
self.featurizer.eval()
|
||||
|
||||
# Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
|
||||
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
|
||||
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
|
||||
self.featurizer.forward = unpack_tuple(
|
||||
partial(self.featurizer.get_intermediate_layers,
|
||||
n={len(self.featurizer.blocks) - 2}))
|
||||
|
||||
# Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
|
||||
assert isinstance(self.featurizer, VisionTransformer), (
|
||||
"Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
|
||||
"file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
|
||||
)
|
||||
|
||||
# Get Config =>> Note :: Override default image size to ensure correct image transform
|
||||
self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
|
||||
self.data_cfg["input_size"] = (3, self.default_image_size,
|
||||
self.default_image_size)
|
||||
|
||||
# Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
|
||||
default_image_transform = timm.data.create_transform(**self.data_cfg,
|
||||
is_training=False)
|
||||
|
||||
# Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
|
||||
if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
|
||||
assert isinstance(default_image_transform,
|
||||
Compose), "Unexpected `default_image_transform`!"
|
||||
assert isinstance(default_image_transform.transforms[0], Resize)
|
||||
default_image_transform = Compose([
|
||||
Resize(self.default_image_size,
|
||||
interpolation=default_image_transform.transforms[0].
|
||||
interpolation),
|
||||
*default_image_transform.transforms[1:],
|
||||
])
|
||||
|
||||
# Switch on `image_resize_strategy`
|
||||
if self.image_resize_strategy == "resize-naive":
|
||||
assert isinstance(default_image_transform,
|
||||
Compose), "Unexpected `default_image_transform`!"
|
||||
assert isinstance(default_image_transform.transforms[0], Resize)
|
||||
|
||||
target_size = (self.default_image_size, self.default_image_size)
|
||||
self.image_transform = Compose([
|
||||
Resize(target_size,
|
||||
interpolation=default_image_transform.transforms[0].
|
||||
interpolation),
|
||||
*default_image_transform.transforms[1:],
|
||||
])
|
||||
|
||||
elif self.image_resize_strategy == "resize-crop":
|
||||
self.image_transform = default_image_transform
|
||||
|
||||
elif self.image_resize_strategy == "letterbox":
|
||||
assert isinstance(default_image_transform,
|
||||
Compose), "Unexpected `default_image_transform`!"
|
||||
assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
|
||||
|
||||
# Compute Padding Fill Value (rescaled normalization mean if applicable)
|
||||
fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
|
||||
|
||||
# Build New Transform
|
||||
self.image_transform = Compose(
|
||||
[LetterboxPad(fill), *default_image_transform.transforms])
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
|
||||
)
|
||||
|
||||
def get_fsdp_wrapping_policy(self) -> Callable:
|
||||
"""Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
|
||||
vit_wrap_policy = partial(_module_wrap_policy,
|
||||
module_classes={VisionTransformer})
|
||||
transformer_block_policy = partial(transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={Block})
|
||||
return partial(_or_policy,
|
||||
policies=[vit_wrap_policy, transformer_block_policy])
|
||||
|
||||
def forward(
|
||||
self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
"""Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
|
||||
return self.featurizer(pixel_values)
|
||||
|
||||
@property
|
||||
def default_image_resolution(self) -> Tuple[int, int, int]:
|
||||
return self.data_cfg["input_size"]
|
||||
|
||||
@property
|
||||
def embed_dim(self) -> int:
|
||||
return self.featurizer.embed_dim
|
||||
|
||||
@property
|
||||
def num_patches(self) -> int:
|
||||
return self.featurizer.patch_embed.num_patches
|
||||
|
||||
@property
|
||||
def half_precision_dtype(self) -> torch.dtype:
|
||||
return self.dtype
|
||||
273
src/unifolm_wma/modules/vision/dinosiglip_vit.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
dinosiglip_vit.py
|
||||
|
||||
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Tuple
|
||||
from PIL import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize, Normalize
|
||||
|
||||
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
|
||||
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
|
||||
|
||||
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
|
||||
DINOSigLIP_VISION_BACKBONES = {
|
||||
"dinosiglip-vit-so-224px": {
|
||||
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
|
||||
"siglip": "vit_so400m_patch14_siglip_224",
|
||||
},
|
||||
"dinosiglip-vit-so-384px": {
|
||||
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
|
||||
"siglip": "vit_so400m_patch14_siglip_384",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DinoSigLIPImageTransform:
|
||||
dino_image_transform: ImageTransform
|
||||
siglip_image_transform: ImageTransform
|
||||
is_prismatic: bool = True
|
||||
|
||||
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
|
||||
return {
|
||||
"dino": self.dino_image_transform(img, **kwargs),
|
||||
"siglip": self.siglip_image_transform(img, **kwargs)
|
||||
}
|
||||
|
||||
|
||||
class DinoSigLIPViTBackbone(VisionBackbone):
|
||||
|
||||
def __init__(self,
|
||||
vision_backbone_id: str,
|
||||
image_resize_strategy: str,
|
||||
arch_specifier: str,
|
||||
output_dim: int,
|
||||
pretrained_checkpoint=None,
|
||||
freeze=True,
|
||||
default_image_size: int = 224) -> None:
|
||||
super().__init__(vision_backbone_id,
|
||||
image_resize_strategy,
|
||||
default_image_size=default_image_size)
|
||||
self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
|
||||
vision_backbone_id]["dino"]
|
||||
self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
|
||||
vision_backbone_id]["siglip"]
|
||||
|
||||
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
|
||||
self.dino_featurizer: VisionTransformer = timm.create_model(
|
||||
self.dino_timm_path_or_url,
|
||||
pretrained=True,
|
||||
num_classes=0,
|
||||
img_size=self.default_image_size)
|
||||
if pretrained_checkpoint:
|
||||
ckpt = pretrained_checkpoint + '/openvla_dino.pt'
|
||||
self.dino_featurizer.load_state_dict(
|
||||
torch.load(ckpt, weights_only=True))
|
||||
print('>>> load dino weights')
|
||||
if freeze:
|
||||
self.dino_featurizer.eval()
|
||||
for param in self.dino_featurizer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.siglip_featurizer: VisionTransformer = timm.create_model(
|
||||
self.siglip_timm_path_or_url,
|
||||
pretrained=True,
|
||||
num_classes=0,
|
||||
img_size=self.default_image_size)
|
||||
if pretrained_checkpoint:
|
||||
ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
|
||||
self.siglip_featurizer.load_state_dict(
|
||||
torch.load(ckpt, weights_only=True))
|
||||
print('>>> load siglip weights')
|
||||
if freeze:
|
||||
self.siglip_featurizer.eval()
|
||||
for param in self.siglip_featurizer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
|
||||
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
|
||||
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
|
||||
self.dino_featurizer.forward = unpack_tuple(
|
||||
partial(self.dino_featurizer.get_intermediate_layers,
|
||||
n={len(self.dino_featurizer.blocks) - 2}))
|
||||
self.siglip_featurizer.forward = unpack_tuple(
|
||||
partial(self.siglip_featurizer.get_intermediate_layers,
|
||||
n={len(self.siglip_featurizer.blocks) - 2}))
|
||||
|
||||
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
|
||||
self.dino_data_cfg = timm.data.resolve_model_data_config(
|
||||
self.dino_featurizer)
|
||||
self.dino_data_cfg["input_size"] = (3, self.default_image_size,
|
||||
self.default_image_size)
|
||||
|
||||
self.siglip_data_cfg = timm.data.resolve_model_data_config(
|
||||
self.siglip_featurizer)
|
||||
self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
|
||||
self.default_image_size)
|
||||
|
||||
# Initialize *both* Transforms
|
||||
self.default_dino_transform = timm.data.create_transform(
|
||||
**self.dino_data_cfg, is_training=False)
|
||||
self.default_siglip_transform = timm.data.create_transform(
|
||||
**self.siglip_data_cfg, is_training=False)
|
||||
|
||||
# Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
|
||||
assert isinstance(self.default_siglip_transform,
|
||||
Compose), "Unexpected `default_image_transform`!"
|
||||
assert isinstance(self.default_siglip_transform.transforms[0], Resize)
|
||||
self.default_siglip_transform = Compose([
|
||||
Resize(self.default_image_size,
|
||||
interpolation=self.default_siglip_transform.transforms[0].
|
||||
interpolation),
|
||||
*self.default_siglip_transform.transforms[1:],
|
||||
])
|
||||
|
||||
if self.image_resize_strategy == "resize-naive":
|
||||
assert isinstance(
|
||||
self.default_dino_transform,
|
||||
Compose), "Unexpected `default_dino_image_transform`!"
|
||||
assert isinstance(
|
||||
self.default_siglip_transform,
|
||||
Compose), "Unexpected `default_siglip_image_transform`!"
|
||||
assert isinstance(self.default_dino_transform.transforms[0],
|
||||
Resize)
|
||||
assert isinstance(self.default_siglip_transform.transforms[0],
|
||||
Resize)
|
||||
|
||||
self.target_size = (self.default_image_size,
|
||||
self.default_image_size)
|
||||
dino_transform = Compose([
|
||||
Resize(self.target_size,
|
||||
interpolation=self.default_dino_transform.transforms[0].
|
||||
interpolation),
|
||||
*self.default_dino_transform.transforms[1:],
|
||||
])
|
||||
siglip_transform = Compose([
|
||||
Resize(self.target_size,
|
||||
interpolation=self.default_siglip_transform.
|
||||
transforms[0].interpolation),
|
||||
*self.default_siglip_transform.transforms[1:],
|
||||
])
|
||||
|
||||
self.image_transform = DinoSigLIPImageTransform(
|
||||
dino_transform, siglip_transform)
|
||||
|
||||
elif self.image_resize_strategy == "resize-crop":
|
||||
self.image_transform = DinoSigLIPImageTransform(
|
||||
self.default_dino_transform, self.default_siglip_transform)
|
||||
|
||||
elif self.image_resize_strategy == "letterbox":
|
||||
assert isinstance(self.default_dino_transform,
|
||||
Compose), "Unexpected `default_dino_transform`!"
|
||||
assert isinstance(
|
||||
self.default_siglip_transform,
|
||||
Compose), "Unexpected `default_siglip_transform`!"
|
||||
assert ("mean" in self.dino_data_cfg
|
||||
and "mean" in self.siglip_data_cfg
|
||||
), "DinoSigLIP `data_cfg` missing `mean`!"
|
||||
|
||||
# Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
|
||||
dino_fill = tuple(
|
||||
[int(x * 255) for x in self.dino_data_cfg["mean"]])
|
||||
siglip_fill = tuple(
|
||||
[int(x * 255) for x in self.siglip_data_cfg["mean"]])
|
||||
|
||||
# Build New Transform
|
||||
self.image_transform = DinoSigLIPImageTransform(
|
||||
Compose([
|
||||
LetterboxPad(dino_fill),
|
||||
*self.default_dino_transform.transforms
|
||||
]),
|
||||
Compose([
|
||||
LetterboxPad(siglip_fill),
|
||||
*self.default_siglip_transform.transforms
|
||||
]),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
|
||||
)
|
||||
|
||||
self.arch_specifier = arch_specifier
|
||||
if arch_specifier == "linear":
|
||||
self.projector = LinearProjector(self.embed_dim, output_dim)
|
||||
elif arch_specifier.endswith("fused-gelu-mlp"):
|
||||
self.projector = FusedMLPProjector(self.embed_dim, output_dim)
|
||||
elif arch_specifier.endswith("gelu-mlp"):
|
||||
self.projector = MLPProjector(self.embed_dim, output_dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"PrismaticVLM with `{arch_specifier = }` is not supported!")
|
||||
|
||||
self.on_gpu = False
|
||||
|
||||
def get_fsdp_wrapping_policy(self) -> Callable:
|
||||
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
|
||||
vit_wrap_policy = partial(_module_wrap_policy,
|
||||
module_classes={VisionTransformer})
|
||||
transformer_block_policy = partial(transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={Block})
|
||||
return partial(_or_policy,
|
||||
policies=[vit_wrap_policy, transformer_block_policy])
|
||||
|
||||
def forward(self, img) -> torch.Tensor:
|
||||
img = torch.clamp(img.float(), -1., 1.)
|
||||
img = (img + 1.0) / 2.0
|
||||
img = img * 255
|
||||
|
||||
resize = transforms.Resize(min(self.target_size),
|
||||
interpolation=self.default_dino_transform.
|
||||
transforms[0].interpolation,
|
||||
max_size=None,
|
||||
antialias=True)
|
||||
center_crop = transforms.CenterCrop(self.target_size)
|
||||
img = center_crop(resize(img))
|
||||
|
||||
dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
|
||||
0.4060]),
|
||||
std=torch.tensor([0.2290, 0.2240, 0.2250]))
|
||||
siglip_normalizer = Normalize(
|
||||
mean=torch.tensor([0.5000, 0.5000, 0.5000]),
|
||||
std=torch.tensor([0.5000, 0.5000, 0.5000]))
|
||||
pixel_values = {
|
||||
'dino': dino_normalizer(img),
|
||||
'siglip': siglip_normalizer(img)
|
||||
}
|
||||
|
||||
if self.on_gpu:
|
||||
pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
|
||||
elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
|
||||
self.on_gpu = True
|
||||
"""Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
|
||||
dino_patches = self.dino_featurizer(pixel_values["dino"])
|
||||
siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
|
||||
|
||||
return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))
|
||||
|
||||
@property
|
||||
def default_image_resolution(self) -> Tuple[int, int, int]:
|
||||
return self.dino_data_cfg["input_size"]
|
||||
|
||||
@property
|
||||
def embed_dim(self) -> int:
|
||||
return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
|
||||
|
||||
@property
|
||||
def num_patches(self) -> int:
|
||||
assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
|
||||
return self.dino_featurizer.patch_embed.num_patches
|
||||
|
||||
@property
|
||||
def half_precision_dtype(self) -> torch.dtype:
|
||||
return torch.bfloat16
|
||||