AudioX
This commit is contained in:
parent
1b0f86cde5
commit
1bf5691a38
49 changed files with 13764 additions and 7 deletions
180
.gitignore
vendored
180
.gitignore
vendored
|
|
@ -1,2 +1,178 @@
|
||||||
# /static/videos/*.mp4
|
# Byte-compiled / optimized / DLL files
|
||||||
# /static/videos/*.mov
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
*.ckpt
|
||||||
|
*.wav
|
||||||
|
# *.mp4
|
||||||
|
*.mp3
|
||||||
|
*.jsonl
|
||||||
|
wandb/*
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model/
|
||||||
|
logs/
|
||||||
|
log/
|
||||||
|
saved_ckpt/
|
||||||
|
wandb/
|
||||||
|
data/
|
||||||
|
demo_result/
|
||||||
|
model/
|
||||||
22
LICENSE
Normal file
22
LICENSE
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Stability AI
|
||||||
|
Copyright (c) 2025 AudioX, HKUST
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
143
README.md
143
README.md
|
|
@ -1,14 +1,20 @@
|
||||||
# AudioX: Diffusion Transformer for Anything-to-Audio Generation
|
# 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation
|
||||||
|
|
||||||
|
[](https://arxiv.org/abs/2503.10522)
|
||||||
|
[](https://zeyuet.github.io/AudioX/)
|
||||||
|
[](https://huggingface.co/HKUSTAudio/AudioX)
|
||||||
|
|
||||||
|
|
||||||
[](https://arxiv.org/pdf/2503.10522) [](https://zeyuet.github.io/AudioX/)
|
---
|
||||||
|
|
||||||
|
**This is the official repository for "[AudioX: Diffusion Transformer for Anything-to-Audio Generation](https://arxiv.org/pdf/2503.10522)".**
|
||||||
|
|
||||||
**This is the repository for "AudioX: Diffusion Transformer for Anything-to-Audio Generation".**
|
|
||||||
|
|
||||||
## 📺 Demo Video
|
## 📺 Demo Video
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be
|
https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
## ✨ Abstract
|
## ✨ Abstract
|
||||||
|
|
@ -34,8 +40,135 @@ Audio and music generation have emerged as crucial tasks in many applications, y
|
||||||
|
|
||||||
|
|
||||||
## Code
|
## Code
|
||||||
To be released.
|
|
||||||
|
|
||||||
|
|
||||||
<hr>
|
### 🛠️ Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/ZeyueT/AudioX.git
|
||||||
|
cd AudioX
|
||||||
|
conda create -n AudioX python=3.8.20
|
||||||
|
conda activate AudioX
|
||||||
|
pip install git+https://github.com/ZeyueT/AudioX.git
|
||||||
|
conda install -c conda-forge ffmpeg libsndfile
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🪄 Pretrained Checkpoints
|
||||||
|
|
||||||
|
Download the pretrained model from 🤗 [AudioX on Hugging Face](https://huggingface.co/HKUSTAudio/AudioX):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir -p model
|
||||||
|
wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt -O model/model.ckpt
|
||||||
|
wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json -O model/config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🤗 Gradio Demo
|
||||||
|
|
||||||
|
To launch the Gradio demo locally, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 run_gradio.py \
|
||||||
|
--model-config model/config.json \
|
||||||
|
--share
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 🎯 Prompt Configuration Examples
|
||||||
|
|
||||||
|
| Task | `video_path` | `text_prompt` | `audio_path` |
|
||||||
|
|:---------------------|:-------------------|:----------------------------------------------|:-------------|
|
||||||
|
| Text-to-Audio (T2A) | `None` | `"Typing on a keyboard"` | `None` |
|
||||||
|
| Text-to-Music (T2M) | `None` | `"A music with piano and violin"` | `None` |
|
||||||
|
| Video-to-Audio (V2A) | `"video_path.mp4"` | `"Generate general audio for the video"` | `None` |
|
||||||
|
| Video-to-Music (V2M) | `"video_path.mp4"` | `"Generate music for the video"` | `None` |
|
||||||
|
| TV-to-Audio (TV2A) | `"video_path.mp4"` | `"Ocean waves crashing with people laughing"` | `None` |
|
||||||
|
| TV-to-Music (TV2M) | `"video_path.mp4"` | `"Generate music with piano instrument"` | `None` |
|
||||||
|
|
||||||
|
### 🖥️ Script Inference
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from einops import rearrange
|
||||||
|
from stable_audio_tools import get_pretrained_model
|
||||||
|
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
||||||
|
from stable_audio_tools.data.utils import read_video, merge_video_audio
|
||||||
|
from stable_audio_tools.data.utils import load_and_process_audio
|
||||||
|
import os
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
# Download model
|
||||||
|
model, model_config = get_pretrained_model("HKUSTAudio/AudioX")
|
||||||
|
sample_rate = model_config["sample_rate"]
|
||||||
|
sample_size = model_config["sample_size"]
|
||||||
|
target_fps = model_config["video_fps"]
|
||||||
|
seconds_start = 0
|
||||||
|
seconds_total = 10
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# for video-to-music generation
|
||||||
|
video_path = "example/V2M_sample-1.mp4"
|
||||||
|
text_prompt = "Generate music for the video"
|
||||||
|
audio_path = None
|
||||||
|
|
||||||
|
video_tensor = read_video(video_path, seek_time=0, duration=seconds_total, target_fps=target_fps)
|
||||||
|
audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
|
||||||
|
|
||||||
|
conditioning = [{
|
||||||
|
"video_prompt": [video_tensor.unsqueeze(0)],
|
||||||
|
"text_prompt": text_prompt,
|
||||||
|
"audio_prompt": audio_tensor.unsqueeze(0),
|
||||||
|
"seconds_start": seconds_start,
|
||||||
|
"seconds_total": seconds_total
|
||||||
|
}]
|
||||||
|
|
||||||
|
# Generate stereo audio
|
||||||
|
output = generate_diffusion_cond(
|
||||||
|
model,
|
||||||
|
steps=250,
|
||||||
|
cfg_scale=7,
|
||||||
|
conditioning=conditioning,
|
||||||
|
sample_size=sample_size,
|
||||||
|
sigma_min=0.3,
|
||||||
|
sigma_max=500,
|
||||||
|
sampler_type="dpmpp-3m-sde",
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rearrange audio batch to a single sequence
|
||||||
|
output = rearrange(output, "b d n -> d (b n)")
|
||||||
|
|
||||||
|
# Peak normalize, clip, convert to int16, and save to file
|
||||||
|
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
||||||
|
torchaudio.save("output.wav", output, sample_rate)
|
||||||
|
|
||||||
|
if video_path is not None and os.path.exists(video_path):
|
||||||
|
merge_video_audio(video_path, "output.wav", "output.mp4", 0, seconds_total)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 🚀 Citation
|
||||||
|
|
||||||
|
If you find our work useful, please consider citing:
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{tian2025audiox,
|
||||||
|
title={AudioX: Diffusion Transformer for Anything-to-Audio Generation},
|
||||||
|
author={Tian, Zeyue and Jin, Yizhu and Liu, Zhaoyang and Yuan, Ruibin and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike},
|
||||||
|
journal={arXiv preprint arXiv:2503.10522},
|
||||||
|
year={2025}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📭 Contact
|
||||||
|
|
||||||
|
If you have any comments or questions, feel free to contact Zeyue Tian(ztianad@connect.ust.hk).
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Please follow [MIT License](./LICENSE).
|
||||||
56
defaults.ini
Normal file
56
defaults.ini
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
|
||||||
|
[DEFAULTS]
|
||||||
|
|
||||||
|
#name of the run
|
||||||
|
name = stable_audio_tools
|
||||||
|
|
||||||
|
# the batch size
|
||||||
|
batch_size = 8
|
||||||
|
|
||||||
|
# number of GPUs to use for training
|
||||||
|
num_gpus = 1
|
||||||
|
|
||||||
|
# number of nodes to use for training
|
||||||
|
num_nodes = 1
|
||||||
|
|
||||||
|
# Multi-GPU strategy for PyTorch Lightning
|
||||||
|
strategy = ""
|
||||||
|
|
||||||
|
# Precision to use for training
|
||||||
|
precision = "16-mixed"
|
||||||
|
|
||||||
|
# number of CPU workers for the DataLoader
|
||||||
|
num_workers = 8
|
||||||
|
|
||||||
|
# the random seed
|
||||||
|
seed = 42
|
||||||
|
|
||||||
|
# Batches for gradient accumulation
|
||||||
|
accum_batches = 1
|
||||||
|
|
||||||
|
# Number of steps between checkpoints
|
||||||
|
checkpoint_every = 10000
|
||||||
|
|
||||||
|
# trainer checkpoint file to restart training from
|
||||||
|
ckpt_path = ''
|
||||||
|
|
||||||
|
# model checkpoint file to start a new training run from
|
||||||
|
pretrained_ckpt_path = ''
|
||||||
|
|
||||||
|
# Checkpoint path for the pretransform model if needed
|
||||||
|
pretransform_ckpt_path = ''
|
||||||
|
|
||||||
|
# configuration model specifying model hyperparameters
|
||||||
|
model_config = ''
|
||||||
|
|
||||||
|
# configuration for datasets
|
||||||
|
dataset_config = ''
|
||||||
|
|
||||||
|
# directory to save the checkpoints in
|
||||||
|
save_dir = ''
|
||||||
|
|
||||||
|
# gradient_clip_val passed into PyTorch Lightning Trainer
|
||||||
|
gradient_clip_val = 0.0
|
||||||
|
|
||||||
|
# remove the weight norm from the pretransform model
|
||||||
|
remove_pretransform_weight_norm = ''
|
||||||
BIN
example/V2A_sample-1.mp4
Normal file
BIN
example/V2A_sample-1.mp4
Normal file
Binary file not shown.
BIN
example/V2A_sample-2.mp4
Normal file
BIN
example/V2A_sample-2.mp4
Normal file
Binary file not shown.
BIN
example/V2A_sample-3.mp4
Normal file
BIN
example/V2A_sample-3.mp4
Normal file
Binary file not shown.
BIN
example/V2M_sample-1.mp4
Normal file
BIN
example/V2M_sample-1.mp4
Normal file
Binary file not shown.
BIN
example/V2M_sample-2.mp4
Normal file
BIN
example/V2M_sample-2.mp4
Normal file
Binary file not shown.
BIN
example/V2M_sample-3.mp4
Normal file
BIN
example/V2M_sample-3.mp4
Normal file
Binary file not shown.
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
32
run_gradio.py
Normal file
32
run_gradio.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
from stable_audio_tools import get_pretrained_model
|
||||||
|
from stable_audio_tools.interface.gradio import create_ui
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
interface = create_ui(
|
||||||
|
model_config_path = args.model_config,
|
||||||
|
ckpt_path=args.ckpt_path,
|
||||||
|
pretrained_name=args.pretrained_name,
|
||||||
|
pretransform_ckpt_path=args.pretransform_ckpt_path,
|
||||||
|
model_half=args.model_half
|
||||||
|
)
|
||||||
|
interface.queue()
|
||||||
|
interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description='Run gradio interface')
|
||||||
|
parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
|
||||||
|
parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
|
||||||
|
parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
|
||||||
|
parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
|
||||||
|
parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False)
|
||||||
|
parser.add_argument('--username', type=str, help='Gradio username', required=False)
|
||||||
|
parser.add_argument('--password', type=str, help='Gradio password', required=False)
|
||||||
|
parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
47
setup.py
Normal file
47
setup.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='AudioX',
|
||||||
|
version='0.1.0',
|
||||||
|
url='https://github.com/ZeyueT/AudioX.git',
|
||||||
|
author='AudioX, HKUST',
|
||||||
|
description='Training and inference tools for generative audio models from AudioX',
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=[
|
||||||
|
'aeiou',
|
||||||
|
'alias-free-torch==0.0.6',
|
||||||
|
'auraloss==0.4.0',
|
||||||
|
'descript-audio-codec==1.0.0',
|
||||||
|
'decord==0.6.0',
|
||||||
|
'einops',
|
||||||
|
'einops_exts',
|
||||||
|
'ema-pytorch==0.2.3',
|
||||||
|
'encodec==0.1.1',
|
||||||
|
'gradio==4.44.1',
|
||||||
|
'gradio_client==1.3.0',
|
||||||
|
'huggingface_hub',
|
||||||
|
'importlib-resources==5.12.0',
|
||||||
|
'k-diffusion==0.1.1',
|
||||||
|
'laion-clap==1.1.6',
|
||||||
|
'local-attention==1.8.6',
|
||||||
|
'pandas==2.0.2',
|
||||||
|
'pedalboard==0.9.14',
|
||||||
|
'prefigure==0.0.9',
|
||||||
|
'pytorch_lightning==2.4.0',
|
||||||
|
'PyWavelets==1.4.1',
|
||||||
|
'safetensors',
|
||||||
|
'sentencepiece==0.1.99',
|
||||||
|
'torch==2.4.1',
|
||||||
|
'torchaudio==2.4.1',
|
||||||
|
'torchmetrics==1.5.2',
|
||||||
|
'tqdm',
|
||||||
|
'transformers==4.46.2',
|
||||||
|
'v-diffusion-pytorch==0.0.2',
|
||||||
|
'vector-quantize-pytorch==1.9.14',
|
||||||
|
'wandb',
|
||||||
|
'webdataset==0.2.48',
|
||||||
|
'x-transformers==1.42.11',
|
||||||
|
'flash_attn'
|
||||||
|
],
|
||||||
|
|
||||||
|
)
|
||||||
2
stable_audio_tools/__init__.py
Normal file
2
stable_audio_tools/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .models.factory import create_model_from_config, create_model_from_config_path
|
||||||
|
from .models.pretrained import get_pretrained_model
|
||||||
0
stable_audio_tools/inference/__init__.py
Normal file
0
stable_audio_tools/inference/__init__.py
Normal file
275
stable_audio_tools/inference/generation.py
Normal file
275
stable_audio_tools/inference/generation.py
Normal file
|
|
@ -0,0 +1,275 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import typing as tp
|
||||||
|
import math
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
|
||||||
|
from .utils import prepare_audio
|
||||||
|
from .sampling import sample, sample_k, sample_rf
|
||||||
|
from ..data.utils import PadCrop
|
||||||
|
|
||||||
|
def generate_diffusion_uncond(
|
||||||
|
model,
|
||||||
|
steps: int = 250,
|
||||||
|
batch_size: int = 1,
|
||||||
|
sample_size: int = 2097152,
|
||||||
|
seed: int = -1,
|
||||||
|
device: str = "cuda",
|
||||||
|
init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
|
||||||
|
init_noise_level: float = 1.0,
|
||||||
|
return_latents = False,
|
||||||
|
**sampler_kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# The length of the output in audio samples
|
||||||
|
audio_sample_size = sample_size
|
||||||
|
|
||||||
|
# If this is latent diffusion, change sample_size instead to the downsampled latent size
|
||||||
|
if model.pretransform is not None:
|
||||||
|
sample_size = sample_size // model.pretransform.downsampling_ratio
|
||||||
|
|
||||||
|
# Seed
|
||||||
|
# The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
|
||||||
|
seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
|
||||||
|
# seed = 777
|
||||||
|
print(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
# Define the initial noise immediately after setting the seed
|
||||||
|
noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
|
||||||
|
|
||||||
|
if init_audio is not None:
|
||||||
|
# The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
|
||||||
|
in_sr, init_audio = init_audio
|
||||||
|
|
||||||
|
io_channels = model.io_channels
|
||||||
|
|
||||||
|
# For latent models, set the io_channels to the autoencoder's io_channels
|
||||||
|
if model.pretransform is not None:
|
||||||
|
io_channels = model.pretransform.io_channels
|
||||||
|
|
||||||
|
# Prepare the initial audio for use by the model
|
||||||
|
init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
|
||||||
|
|
||||||
|
# For latent models, encode the initial audio into latents
|
||||||
|
if model.pretransform is not None:
|
||||||
|
init_audio = model.pretransform.encode(init_audio)
|
||||||
|
|
||||||
|
init_audio = init_audio.repeat(batch_size, 1, 1)
|
||||||
|
else:
|
||||||
|
# The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
|
||||||
|
init_audio = None
|
||||||
|
init_noise_level = None
|
||||||
|
|
||||||
|
# Inpainting mask
|
||||||
|
|
||||||
|
if init_audio is not None:
|
||||||
|
# variations
|
||||||
|
sampler_kwargs["sigma_max"] = init_noise_level
|
||||||
|
mask = None
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
# Now the generative AI part:
|
||||||
|
|
||||||
|
diff_objective = model.diffusion_objective
|
||||||
|
|
||||||
|
if diff_objective == "v":
|
||||||
|
# k-diffusion denoising process go!
|
||||||
|
sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
|
||||||
|
elif diff_objective == "rectified_flow":
|
||||||
|
sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
|
||||||
|
|
||||||
|
# Denoising process done.
|
||||||
|
# If this is latent diffusion, decode latents back into audio
|
||||||
|
if model.pretransform is not None and not return_latents:
|
||||||
|
sampled = model.pretransform.decode(sampled)
|
||||||
|
|
||||||
|
# Return audio
|
||||||
|
return sampled
|
||||||
|
|
||||||
|
|
||||||
|
def generate_diffusion_cond(
|
||||||
|
model,
|
||||||
|
steps: int = 250,
|
||||||
|
cfg_scale=6,
|
||||||
|
conditioning: dict = None,
|
||||||
|
conditioning_tensors: tp.Optional[dict] = None,
|
||||||
|
negative_conditioning: dict = None,
|
||||||
|
negative_conditioning_tensors: tp.Optional[dict] = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
sample_size: int = 2097152,
|
||||||
|
sample_rate: int = 48000,
|
||||||
|
seed: int = -1,
|
||||||
|
device: str = "cuda",
|
||||||
|
init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
|
||||||
|
init_noise_level: float = 1.0,
|
||||||
|
mask_args: dict = None,
|
||||||
|
return_latents = False,
|
||||||
|
**sampler_kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Generate audio from a prompt using a diffusion model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The diffusion model to use for generation.
|
||||||
|
steps: The number of diffusion steps to use.
|
||||||
|
cfg_scale: Classifier-free guidance scale
|
||||||
|
conditioning: A dictionary of conditioning parameters to use for generation.
|
||||||
|
conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
|
||||||
|
batch_size: The batch size to use for generation.
|
||||||
|
sample_size: The length of the audio to generate, in samples.
|
||||||
|
sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
|
||||||
|
seed: The random seed to use for generation, or -1 to use a random seed.
|
||||||
|
device: The device to use for generation.
|
||||||
|
init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
|
||||||
|
init_noise_level: The noise level to use when generating from an initial audio sample.
|
||||||
|
return_latents: Whether to return the latents used for generation instead of the decoded audio.
|
||||||
|
**sampler_kwargs: Additional keyword arguments to pass to the sampler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The length of the output in audio samples
|
||||||
|
audio_sample_size = sample_size
|
||||||
|
|
||||||
|
# If this is latent diffusion, change sample_size instead to the downsampled latent size
|
||||||
|
if model.pretransform is not None:
|
||||||
|
sample_size = sample_size // model.pretransform.downsampling_ratio
|
||||||
|
|
||||||
|
# Seed
|
||||||
|
# The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
|
||||||
|
seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
|
||||||
|
# seed = 777
|
||||||
|
# print(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
# Define the initial noise immediately after setting the seed
|
||||||
|
noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
|
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
# Conditioning
|
||||||
|
assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
|
||||||
|
if conditioning_tensors is None:
|
||||||
|
conditioning_tensors = model.conditioner(conditioning, device)
|
||||||
|
conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
|
||||||
|
|
||||||
|
if negative_conditioning is not None or negative_conditioning_tensors is not None:
|
||||||
|
|
||||||
|
if negative_conditioning_tensors is None:
|
||||||
|
negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
|
||||||
|
|
||||||
|
negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
|
||||||
|
else:
|
||||||
|
negative_conditioning_tensors = {}
|
||||||
|
|
||||||
|
if init_audio is not None:
|
||||||
|
# The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
|
||||||
|
in_sr, init_audio = init_audio
|
||||||
|
|
||||||
|
io_channels = model.io_channels
|
||||||
|
|
||||||
|
# For latent models, set the io_channels to the autoencoder's io_channels
|
||||||
|
if model.pretransform is not None:
|
||||||
|
io_channels = model.pretransform.io_channels
|
||||||
|
|
||||||
|
# Prepare the initial audio for use by the model
|
||||||
|
init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
|
||||||
|
|
||||||
|
# For latent models, encode the initial audio into latents
|
||||||
|
if model.pretransform is not None:
|
||||||
|
init_audio = model.pretransform.encode(init_audio)
|
||||||
|
|
||||||
|
init_audio = init_audio.repeat(batch_size, 1, 1)
|
||||||
|
else:
|
||||||
|
# The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
|
||||||
|
init_audio = None
|
||||||
|
init_noise_level = None
|
||||||
|
mask_args = None
|
||||||
|
|
||||||
|
# Inpainting mask
|
||||||
|
if init_audio is not None and mask_args is not None:
|
||||||
|
# Cut and paste init_audio according to cropfrom, pastefrom, pasteto
|
||||||
|
# This is helpful for forward and reverse outpainting
|
||||||
|
cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
|
||||||
|
pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
|
||||||
|
pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
|
||||||
|
assert pastefrom < pasteto, "Paste From should be less than Paste To"
|
||||||
|
croplen = pasteto - pastefrom
|
||||||
|
if cropfrom + croplen > sample_size:
|
||||||
|
croplen = sample_size - cropfrom
|
||||||
|
cropto = cropfrom + croplen
|
||||||
|
pasteto = pastefrom + croplen
|
||||||
|
cutpaste = init_audio.new_zeros(init_audio.shape)
|
||||||
|
cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
|
||||||
|
#print(cropfrom, cropto, pastefrom, pasteto)
|
||||||
|
init_audio = cutpaste
|
||||||
|
# Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
|
||||||
|
mask = build_mask(sample_size, mask_args)
|
||||||
|
mask = mask.to(device)
|
||||||
|
elif init_audio is not None and mask_args is None:
|
||||||
|
# variations
|
||||||
|
sampler_kwargs["sigma_max"] = init_noise_level
|
||||||
|
mask = None
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
model_dtype = next(model.model.parameters()).dtype
|
||||||
|
noise = noise.type(model_dtype)
|
||||||
|
conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
|
||||||
|
# Now the generative AI part:
|
||||||
|
# k-diffusion denoising process go!
|
||||||
|
|
||||||
|
diff_objective = model.diffusion_objective
|
||||||
|
|
||||||
|
if diff_objective == "v":
|
||||||
|
# k-diffusion denoising process go!
|
||||||
|
sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
|
||||||
|
|
||||||
|
elif diff_objective == "rectified_flow":
|
||||||
|
|
||||||
|
if "sigma_min" in sampler_kwargs:
|
||||||
|
del sampler_kwargs["sigma_min"]
|
||||||
|
|
||||||
|
if "sampler_type" in sampler_kwargs:
|
||||||
|
del sampler_kwargs["sampler_type"]
|
||||||
|
|
||||||
|
sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
|
||||||
|
|
||||||
|
# v-diffusion:
|
||||||
|
del noise
|
||||||
|
del conditioning_tensors
|
||||||
|
del conditioning_inputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
# Denoising process done.
|
||||||
|
# If this is latent diffusion, decode latents back into audio
|
||||||
|
|
||||||
|
if model.pretransform is not None and not return_latents:
|
||||||
|
#cast sampled latents to pretransform dtype
|
||||||
|
sampled = sampled.to(next(model.pretransform.parameters()).dtype)
|
||||||
|
sampled = model.pretransform.decode(sampled)
|
||||||
|
|
||||||
|
return sampled
|
||||||
|
|
||||||
|
# builds a softmask given the parameters
|
||||||
|
# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
|
||||||
|
# and anything between is a mixture of old/new
|
||||||
|
# ideally 0.5 is half/half mixture but i haven't figured this out yet
|
||||||
|
def build_mask(sample_size, mask_args):
|
||||||
|
maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
|
||||||
|
maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
|
||||||
|
softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
|
||||||
|
softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
|
||||||
|
marination = mask_args["marination"]
|
||||||
|
# use hann windows for softening the transition (i don't know if this is correct)
|
||||||
|
hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
|
||||||
|
hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
|
||||||
|
# build the mask.
|
||||||
|
mask = torch.zeros((sample_size))
|
||||||
|
mask[maskstart:maskend] = 1
|
||||||
|
mask[maskstart:maskstart+softnessL] = hannL
|
||||||
|
mask[maskend-softnessR:maskend] = hannR
|
||||||
|
# marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
|
||||||
|
if marination > 0:
|
||||||
|
mask = mask * (1-marination)
|
||||||
|
return mask
|
||||||
235
stable_audio_tools/inference/sampling.py
Normal file
235
stable_audio_tools/inference/sampling.py
Normal file
|
|
@ -0,0 +1,235 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from tqdm import trange, tqdm
|
||||||
|
|
||||||
|
import k_diffusion as K
|
||||||
|
|
||||||
|
# Define the noise schedule and sampling loop
|
||||||
|
def get_alphas_sigmas(t):
|
||||||
|
"""Returns the scaling factors for the clean image (alpha) and for the
|
||||||
|
noise (sigma), given a timestep."""
|
||||||
|
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
|
||||||
|
|
||||||
|
def alpha_sigma_to_t(alpha, sigma):
|
||||||
|
"""Returns a timestep, given the scaling factors for the clean image and for
|
||||||
|
the noise."""
|
||||||
|
return torch.atan2(sigma, alpha) / math.pi * 2
|
||||||
|
|
||||||
|
def t_to_alpha_sigma(t):
|
||||||
|
"""Returns the scaling factors for the clean image and for the noise, given
|
||||||
|
a timestep."""
|
||||||
|
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
|
||||||
|
"""Draws samples from a model given starting noise. Euler method"""
|
||||||
|
|
||||||
|
# Make tensor of ones to broadcast the single t values
|
||||||
|
ts = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
# Create the noise schedule
|
||||||
|
t = torch.linspace(sigma_max, 0, steps + 1)
|
||||||
|
|
||||||
|
#alphas, sigmas = 1-t, t
|
||||||
|
|
||||||
|
for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
|
||||||
|
# Broadcast the current timestep to the correct shape
|
||||||
|
t_curr_tensor = t_curr * torch.ones(
|
||||||
|
(x.shape[0],), dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
dt = t_prev - t_curr # we solve backwards in our formulation
|
||||||
|
x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
|
||||||
|
|
||||||
|
# If we are on the last timestep, output the denoised image
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(model, x, steps, eta, **extra_args):
|
||||||
|
"""Draws samples from a model given starting noise. v-diffusion"""
|
||||||
|
ts = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
# Create the noise schedule
|
||||||
|
t = torch.linspace(1, 0, steps + 1)[:-1]
|
||||||
|
|
||||||
|
alphas, sigmas = get_alphas_sigmas(t)
|
||||||
|
|
||||||
|
# The sampling loop
|
||||||
|
for i in trange(steps):
|
||||||
|
|
||||||
|
# Get the model output (v, the predicted velocity)
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
v = model(x, ts * t[i], **extra_args).float()
|
||||||
|
|
||||||
|
# Predict the noise and the denoised image
|
||||||
|
pred = x * alphas[i] - v * sigmas[i]
|
||||||
|
eps = x * sigmas[i] + v * alphas[i]
|
||||||
|
|
||||||
|
# If we are not on the last timestep, compute the noisy image for the
|
||||||
|
# next timestep.
|
||||||
|
if i < steps - 1:
|
||||||
|
# If eta > 0, adjust the scaling factor for the predicted noise
|
||||||
|
# downward according to the amount of additional noise to add
|
||||||
|
ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
|
||||||
|
(1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
|
||||||
|
adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
|
||||||
|
|
||||||
|
# Recombine the predicted noise and predicted denoised image in the
|
||||||
|
# correct proportions for the next step
|
||||||
|
x = pred * alphas[i + 1] + eps * adjusted_sigma
|
||||||
|
|
||||||
|
# Add the correct amount of fresh noise
|
||||||
|
if eta:
|
||||||
|
x += torch.randn_like(x) * ddim_sigma
|
||||||
|
|
||||||
|
# If we are on the last timestep, output the denoised image
|
||||||
|
return pred
|
||||||
|
|
||||||
|
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
||||||
|
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
||||||
|
def get_bmask(i, steps, mask):
|
||||||
|
strength = (i+1)/(steps)
|
||||||
|
# convert to binary mask
|
||||||
|
bmask = torch.where(mask<=strength,1,0)
|
||||||
|
return bmask
|
||||||
|
|
||||||
|
def make_cond_model_fn(model, cond_fn):
|
||||||
|
def cond_model_fn(x, sigma, **kwargs):
|
||||||
|
with torch.enable_grad():
|
||||||
|
x = x.detach().requires_grad_()
|
||||||
|
denoised = model(x, sigma, **kwargs)
|
||||||
|
cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
|
||||||
|
cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
|
||||||
|
return cond_denoised
|
||||||
|
return cond_model_fn
|
||||||
|
|
||||||
|
# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
|
||||||
|
# init_data is init_audio as latents (if this is latent diffusion)
|
||||||
|
# For sampling, set both init_data and mask to None
|
||||||
|
# For variations, set init_data
|
||||||
|
# For inpainting, set both init_data & mask
|
||||||
|
def sample_k(
|
||||||
|
model_fn,
|
||||||
|
noise,
|
||||||
|
init_data=None,
|
||||||
|
mask=None,
|
||||||
|
steps=100,
|
||||||
|
sampler_type="dpmpp-2m-sde",
|
||||||
|
sigma_min=0.5,
|
||||||
|
sigma_max=50,
|
||||||
|
rho=1.0, device="cuda",
|
||||||
|
callback=None,
|
||||||
|
cond_fn=None,
|
||||||
|
**extra_args
|
||||||
|
):
|
||||||
|
|
||||||
|
denoiser = K.external.VDenoiser(model_fn)
|
||||||
|
|
||||||
|
if cond_fn is not None:
|
||||||
|
denoiser = make_cond_model_fn(denoiser, cond_fn)
|
||||||
|
|
||||||
|
# Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
|
||||||
|
sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
|
||||||
|
# Scale the initial noise by sigma
|
||||||
|
noise = noise * sigmas[0]
|
||||||
|
|
||||||
|
wrapped_callback = callback
|
||||||
|
|
||||||
|
|
||||||
|
if mask is None and init_data is not None:
|
||||||
|
# VARIATION (no inpainting)
|
||||||
|
# set the initial latent to the init_data, and noise it with initial sigma
|
||||||
|
|
||||||
|
x = init_data + noise
|
||||||
|
|
||||||
|
elif mask is not None and init_data is not None:
|
||||||
|
# INPAINTING
|
||||||
|
bmask = get_bmask(0, steps, mask)
|
||||||
|
# initial noising
|
||||||
|
input_noised = init_data + noise
|
||||||
|
# set the initial latent to a mix of init_data and noise, based on step 0's binary mask
|
||||||
|
x = input_noised * bmask + noise * (1-bmask)
|
||||||
|
# define the inpainting callback function (Note: side effects, it mutates x)
|
||||||
|
# See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
|
||||||
|
# callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
# This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
|
||||||
|
def inpainting_callback(args):
|
||||||
|
i = args["i"]
|
||||||
|
x = args["x"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
#denoised = args["denoised"]
|
||||||
|
# noise the init_data input with this step's appropriate amount of noise
|
||||||
|
input_noised = init_data + torch.randn_like(init_data) * sigma
|
||||||
|
# shrinking hard mask
|
||||||
|
bmask = get_bmask(i, steps, mask)
|
||||||
|
# mix input_noise with x, using binary mask
|
||||||
|
new_x = input_noised * bmask + x * (1-bmask)
|
||||||
|
# mutate x
|
||||||
|
x[:,:,:] = new_x[:,:,:]
|
||||||
|
# wrap together the inpainting callback and the user-submitted callback.
|
||||||
|
if callback is None:
|
||||||
|
wrapped_callback = inpainting_callback
|
||||||
|
else:
|
||||||
|
wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
|
||||||
|
else:
|
||||||
|
# SAMPLING
|
||||||
|
# set the initial latent to noise
|
||||||
|
x = noise
|
||||||
|
# x = noise
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
if sampler_type == "k-heun":
|
||||||
|
return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "k-lms":
|
||||||
|
return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "k-dpmpp-2s-ancestral":
|
||||||
|
return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "k-dpm-2":
|
||||||
|
return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "k-dpm-fast":
|
||||||
|
return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "k-dpm-adaptive":
|
||||||
|
return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "dpmpp-2m-sde":
|
||||||
|
return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
elif sampler_type == "dpmpp-3m-sde":
|
||||||
|
return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
||||||
|
|
||||||
|
# Uses discrete Euler sampling for rectified flow models
|
||||||
|
# init_data is init_audio as latents (if this is latent diffusion)
|
||||||
|
# For sampling, set both init_data and mask to None
|
||||||
|
# For variations, set init_data
|
||||||
|
# For inpainting, set both init_data & mask
|
||||||
|
def sample_rf(
|
||||||
|
model_fn,
|
||||||
|
noise,
|
||||||
|
init_data=None,
|
||||||
|
steps=100,
|
||||||
|
sigma_max=1,
|
||||||
|
device="cuda",
|
||||||
|
callback=None,
|
||||||
|
cond_fn=None,
|
||||||
|
**extra_args
|
||||||
|
):
|
||||||
|
|
||||||
|
if sigma_max > 1:
|
||||||
|
sigma_max = 1
|
||||||
|
|
||||||
|
if cond_fn is not None:
|
||||||
|
denoiser = make_cond_model_fn(denoiser, cond_fn)
|
||||||
|
|
||||||
|
wrapped_callback = callback
|
||||||
|
|
||||||
|
if init_data is not None:
|
||||||
|
# VARIATION (no inpainting)
|
||||||
|
# Interpolate the init data and the noise for init audio
|
||||||
|
x = init_data * (1 - sigma_max) + noise * sigma_max
|
||||||
|
else:
|
||||||
|
# SAMPLING
|
||||||
|
# set the initial latent to noise
|
||||||
|
x = noise
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
# TODO: Add callback support
|
||||||
|
#return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
|
||||||
|
return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
|
||||||
35
stable_audio_tools/inference/utils.py
Normal file
35
stable_audio_tools/inference/utils.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
from ..data.utils import PadCrop
|
||||||
|
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
|
||||||
|
def set_audio_channels(audio, target_channels):
|
||||||
|
if target_channels == 1:
|
||||||
|
# Convert to mono
|
||||||
|
audio = audio.mean(1, keepdim=True)
|
||||||
|
elif target_channels == 2:
|
||||||
|
# Convert to stereo
|
||||||
|
if audio.shape[1] == 1:
|
||||||
|
audio = audio.repeat(1, 2, 1)
|
||||||
|
elif audio.shape[1] > 2:
|
||||||
|
audio = audio[:, :2, :]
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||||
|
|
||||||
|
audio = audio.to(device)
|
||||||
|
|
||||||
|
if in_sr != target_sr:
|
||||||
|
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
|
||||||
|
audio = PadCrop(target_length, randomize=False)(audio)
|
||||||
|
|
||||||
|
# Add batch dimension
|
||||||
|
if audio.dim() == 1:
|
||||||
|
audio = audio.unsqueeze(0).unsqueeze(0)
|
||||||
|
elif audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
|
||||||
|
audio = set_audio_channels(audio, target_channels)
|
||||||
|
|
||||||
|
return audio
|
||||||
0
stable_audio_tools/interface/__init__.py
Normal file
0
stable_audio_tools/interface/__init__.py
Normal file
495
stable_audio_tools/interface/gradio.py
Normal file
495
stable_audio_tools/interface/gradio.py
Normal file
|
|
@ -0,0 +1,495 @@
|
||||||
|
import gc
|
||||||
|
import platform
|
||||||
|
import os
|
||||||
|
import subprocess as sp
|
||||||
|
import gradio as gr
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from aeiou.viz import audio_spectrogram_image
|
||||||
|
from einops import rearrange
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
|
||||||
|
from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
|
||||||
|
from ..models.factory import create_model_from_config
|
||||||
|
from ..models.pretrained import get_pretrained_model
|
||||||
|
from ..models.utils import load_ckpt_state_dict
|
||||||
|
from ..inference.utils import prepare_audio
|
||||||
|
from ..training.utils import copy_state_dict
|
||||||
|
from ..data.utils import read_video, merge_video_audio
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
os.environ['TMPDIR'] = './tmp'
|
||||||
|
|
||||||
|
current_model_name = None
|
||||||
|
current_model = None
|
||||||
|
current_sample_rate = None
|
||||||
|
current_sample_size = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
|
||||||
|
global model_configurations
|
||||||
|
|
||||||
|
if pretrained_name is not None:
|
||||||
|
print(f"Loading pretrained model {pretrained_name}")
|
||||||
|
model, model_config = get_pretrained_model(pretrained_name)
|
||||||
|
elif model_config is not None and model_ckpt_path is not None:
|
||||||
|
print(f"Creating model from config")
|
||||||
|
model = create_model_from_config(model_config)
|
||||||
|
print(f"Loading model checkpoint from {model_ckpt_path}")
|
||||||
|
copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
|
||||||
|
sample_rate = model_config["sample_rate"]
|
||||||
|
sample_size = model_config["sample_size"]
|
||||||
|
if pretransform_ckpt_path is not None:
|
||||||
|
print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
|
||||||
|
model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
|
||||||
|
print(f"Done loading pretransform")
|
||||||
|
model.to(device).eval().requires_grad_(False)
|
||||||
|
if model_half:
|
||||||
|
model.to(torch.float16)
|
||||||
|
print(f"Done loading model")
|
||||||
|
return model, model_config, sample_rate, sample_size
|
||||||
|
|
||||||
|
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
|
||||||
|
if audio_path is None:
|
||||||
|
return torch.zeros((2, int(sample_rate * seconds_total)))
|
||||||
|
audio_tensor, sr = torchaudio.load(audio_path)
|
||||||
|
start_index = int(sample_rate * seconds_start)
|
||||||
|
target_length = int(sample_rate * seconds_total)
|
||||||
|
end_index = start_index + target_length
|
||||||
|
audio_tensor = audio_tensor[:, start_index:end_index]
|
||||||
|
if audio_tensor.shape[1] < target_length:
|
||||||
|
pad_length = target_length - audio_tensor.shape[1]
|
||||||
|
audio_tensor = F.pad(audio_tensor, (pad_length, 0))
|
||||||
|
return audio_tensor
|
||||||
|
|
||||||
|
def generate_cond(
|
||||||
|
prompt,
|
||||||
|
negative_prompt=None,
|
||||||
|
video_file=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_prompt_file=None,
|
||||||
|
audio_prompt_path=None,
|
||||||
|
seconds_start=0,
|
||||||
|
seconds_total=10,
|
||||||
|
cfg_scale=6.0,
|
||||||
|
steps=250,
|
||||||
|
preview_every=None,
|
||||||
|
seed=-1,
|
||||||
|
sampler_type="dpmpp-3m-sde",
|
||||||
|
sigma_min=0.03,
|
||||||
|
sigma_max=1000,
|
||||||
|
cfg_rescale=0.0,
|
||||||
|
use_init=False,
|
||||||
|
init_audio=None,
|
||||||
|
init_noise_level=1.0,
|
||||||
|
mask_cropfrom=None,
|
||||||
|
mask_pastefrom=None,
|
||||||
|
mask_pasteto=None,
|
||||||
|
mask_maskstart=None,
|
||||||
|
mask_maskend=None,
|
||||||
|
mask_softnessL=None,
|
||||||
|
mask_softnessR=None,
|
||||||
|
mask_marination=None,
|
||||||
|
batch_size=1
|
||||||
|
):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
print(f"Prompt: {prompt}")
|
||||||
|
preview_images = []
|
||||||
|
if preview_every == 0:
|
||||||
|
preview_every = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
|
||||||
|
except Exception:
|
||||||
|
has_mps = False
|
||||||
|
if has_mps:
|
||||||
|
device = torch.device("mps")
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
model_name = 'default'
|
||||||
|
cfg = model_configurations[model_name]
|
||||||
|
model_config_path = cfg.get("model_config")
|
||||||
|
ckpt_path = cfg.get("ckpt_path")
|
||||||
|
pretrained_name = cfg.get("pretrained_name")
|
||||||
|
pretransform_ckpt_path = cfg.get("pretransform_ckpt_path")
|
||||||
|
model_type = cfg.get("model_type", "diffusion_cond")
|
||||||
|
if model_config_path:
|
||||||
|
with open(model_config_path) as f:
|
||||||
|
model_config = json.load(f)
|
||||||
|
else:
|
||||||
|
model_config = None
|
||||||
|
target_fps = model_config.get("video_fps", 5)
|
||||||
|
global current_model_name, current_model, current_sample_rate, current_sample_size
|
||||||
|
if current_model is None or model_name != current_model_name:
|
||||||
|
current_model, model_config, sample_rate, sample_size = load_model(
|
||||||
|
model_name=model_name,
|
||||||
|
model_config=model_config,
|
||||||
|
model_ckpt_path=ckpt_path,
|
||||||
|
pretrained_name=pretrained_name,
|
||||||
|
pretransform_ckpt_path=pretransform_ckpt_path,
|
||||||
|
device=device,
|
||||||
|
model_half=False
|
||||||
|
)
|
||||||
|
current_model_name = model_name
|
||||||
|
model = current_model
|
||||||
|
current_sample_rate = sample_rate
|
||||||
|
current_sample_size = sample_size
|
||||||
|
else:
|
||||||
|
model = current_model
|
||||||
|
sample_rate = current_sample_rate
|
||||||
|
sample_size = current_sample_size
|
||||||
|
if video_file is not None:
|
||||||
|
video_path = video_file.name
|
||||||
|
elif video_path:
|
||||||
|
video_path = video_path.strip()
|
||||||
|
else:
|
||||||
|
video_path = None
|
||||||
|
|
||||||
|
if audio_prompt_file is not None:
|
||||||
|
print(f'audio_prompt_file: {audio_prompt_file}')
|
||||||
|
audio_path = audio_prompt_file.name
|
||||||
|
elif audio_prompt_path:
|
||||||
|
audio_path = audio_prompt_path.strip()
|
||||||
|
else:
|
||||||
|
audio_path = None
|
||||||
|
|
||||||
|
Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
|
||||||
|
audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
|
||||||
|
|
||||||
|
audio_tensor = audio_tensor.to(device)
|
||||||
|
seconds_input = sample_size / sample_rate
|
||||||
|
print(f'video_path: {video_path}')
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
prompt = ""
|
||||||
|
|
||||||
|
conditioning = [{
|
||||||
|
"video_prompt": [Video_tensors.unsqueeze(0)],
|
||||||
|
"text_prompt": prompt,
|
||||||
|
"audio_prompt": audio_tensor.unsqueeze(0),
|
||||||
|
"seconds_start": seconds_start,
|
||||||
|
"seconds_total": seconds_input
|
||||||
|
}] * batch_size
|
||||||
|
if negative_prompt:
|
||||||
|
negative_conditioning = [{
|
||||||
|
"video_prompt": [Video_tensors.unsqueeze(0)],
|
||||||
|
"text_prompt": negative_prompt,
|
||||||
|
"audio_prompt": audio_tensor.unsqueeze(0),
|
||||||
|
"seconds_start": seconds_start,
|
||||||
|
"seconds_total": seconds_total
|
||||||
|
}] * batch_size
|
||||||
|
else:
|
||||||
|
negative_conditioning = None
|
||||||
|
try:
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
except Exception as e:
|
||||||
|
device = next(current_model.parameters()).device
|
||||||
|
seed = int(seed)
|
||||||
|
if not use_init:
|
||||||
|
init_audio = None
|
||||||
|
input_sample_size = sample_size
|
||||||
|
if init_audio is not None:
|
||||||
|
in_sr, init_audio = init_audio
|
||||||
|
init_audio = torch.from_numpy(init_audio).float().div(32767)
|
||||||
|
if init_audio.dim() == 1:
|
||||||
|
init_audio = init_audio.unsqueeze(0)
|
||||||
|
elif init_audio.dim() == 2:
|
||||||
|
init_audio = init_audio.transpose(0, 1)
|
||||||
|
if in_sr != sample_rate:
|
||||||
|
resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
|
||||||
|
init_audio = resample_tf(init_audio)
|
||||||
|
audio_length = init_audio.shape[-1]
|
||||||
|
if audio_length > sample_size:
|
||||||
|
input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
|
||||||
|
init_audio = (sample_rate, init_audio)
|
||||||
|
def progress_callback(callback_info):
|
||||||
|
nonlocal preview_images
|
||||||
|
denoised = callback_info["denoised"]
|
||||||
|
current_step = callback_info["i"]
|
||||||
|
sigma = callback_info["sigma"]
|
||||||
|
if (current_step - 1) % preview_every == 0:
|
||||||
|
if model.pretransform is not None:
|
||||||
|
denoised = model.pretransform.decode(denoised)
|
||||||
|
denoised = rearrange(denoised, "b d n -> d (b n)")
|
||||||
|
denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
||||||
|
audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
|
||||||
|
preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
|
||||||
|
if mask_cropfrom is not None:
|
||||||
|
mask_args = {
|
||||||
|
"cropfrom": mask_cropfrom,
|
||||||
|
"pastefrom": mask_pastefrom,
|
||||||
|
"pasteto": mask_pasteto,
|
||||||
|
"maskstart": mask_maskstart,
|
||||||
|
"maskend": mask_maskend,
|
||||||
|
"softnessL": mask_softnessL,
|
||||||
|
"softnessR": mask_softnessR,
|
||||||
|
"marination": mask_marination,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
mask_args = None
|
||||||
|
if model_type == "diffusion_cond":
|
||||||
|
audio = generate_diffusion_cond(
|
||||||
|
model,
|
||||||
|
conditioning=conditioning,
|
||||||
|
negative_conditioning=negative_conditioning,
|
||||||
|
steps=steps,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sample_size=input_sample_size,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
sampler_type=sampler_type,
|
||||||
|
sigma_min=sigma_min,
|
||||||
|
sigma_max=sigma_max,
|
||||||
|
init_audio=init_audio,
|
||||||
|
init_noise_level=init_noise_level,
|
||||||
|
mask_args=mask_args,
|
||||||
|
callback=progress_callback if preview_every is not None else None,
|
||||||
|
scale_phi=cfg_rescale
|
||||||
|
)
|
||||||
|
elif model_type == "diffusion_uncond":
|
||||||
|
audio = generate_diffusion_uncond(
|
||||||
|
model,
|
||||||
|
steps=steps,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sample_size=input_sample_size,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
sampler_type=sampler_type,
|
||||||
|
sigma_min=sigma_min,
|
||||||
|
sigma_max=sigma_max,
|
||||||
|
init_audio=init_audio,
|
||||||
|
init_noise_level=init_noise_level,
|
||||||
|
callback=progress_callback if preview_every is not None else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model type: {model_type}")
|
||||||
|
audio = rearrange(audio, "b d n -> d (b n)")
|
||||||
|
audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
||||||
|
file_name = os.path.basename(video_path) if video_path else "output"
|
||||||
|
output_dir = f"demo_result"
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
output_video_path = f"{output_dir}/{file_name}"
|
||||||
|
torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate)
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
if video_path:
|
||||||
|
merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total)
|
||||||
|
audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
|
||||||
|
del video_path
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
return (output_video_path, f"{output_dir}/output.wav")
|
||||||
|
|
||||||
|
def toggle_custom_model(selected_model):
|
||||||
|
return gr.Row.update(visible=(selected_model == "Custom Model"))
|
||||||
|
|
||||||
|
def create_sampling_ui(model_config_map, inpainting=False):
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
gr.Markdown(
|
||||||
|
"""
|
||||||
|
# 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
|
||||||
|
**[Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Tab("Generation"):
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt")
|
||||||
|
negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False)
|
||||||
|
video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path")
|
||||||
|
video_file = gr.File(label="Upload Video File")
|
||||||
|
audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False)
|
||||||
|
audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False)
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=6):
|
||||||
|
with gr.Accordion("Video Params", open=False):
|
||||||
|
seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start")
|
||||||
|
seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False)
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=4):
|
||||||
|
with gr.Accordion("Sampler Params", open=False):
|
||||||
|
steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
|
||||||
|
preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
|
||||||
|
cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale")
|
||||||
|
seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
|
||||||
|
sampler_type_dropdown = gr.Dropdown(
|
||||||
|
["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
|
||||||
|
label="Sampler Type",
|
||||||
|
value="dpmpp-3m-sde"
|
||||||
|
)
|
||||||
|
sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min")
|
||||||
|
sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max")
|
||||||
|
cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=4):
|
||||||
|
with gr.Accordion("Init Audio", open=False, visible=False):
|
||||||
|
init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
|
||||||
|
init_audio_input = gr.Audio(label="Init Audio")
|
||||||
|
init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level")
|
||||||
|
gr.Markdown("## Examples")
|
||||||
|
with gr.Accordion("Click to show examples", open=False):
|
||||||
|
with gr.Row():
|
||||||
|
gr.Markdown("**📝 Task: Text-to-Audio**")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *Typing on a keyboard*")
|
||||||
|
ex1 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *Ocean waves crashing*")
|
||||||
|
ex2 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *Footsteps in snow*")
|
||||||
|
ex3 = gr.Button("Load Example")
|
||||||
|
with gr.Row():
|
||||||
|
gr.Markdown("**🎶 Task: Text-to-Music**")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
|
||||||
|
ex4 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
|
||||||
|
ex5 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
|
||||||
|
ex6 = gr.Button("Load Example")
|
||||||
|
with gr.Row():
|
||||||
|
gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2A_sample-1.mp4")
|
||||||
|
ex7 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2A_sample-2.mp4")
|
||||||
|
ex8 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2A_sample-3.mp4")
|
||||||
|
ex9 = gr.Button("Load Example")
|
||||||
|
with gr.Row():
|
||||||
|
gr.Markdown("**🎵 Task: Video-to-Music**\nPrompt: *Generate music for the video*")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2M_sample-1.mp4")
|
||||||
|
ex10 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2M_sample-2.mp4")
|
||||||
|
ex11 = gr.Button("Load Example")
|
||||||
|
with gr.Column(scale=1.2):
|
||||||
|
gr.Video("example/V2M_sample-3.mp4")
|
||||||
|
ex12 = gr.Button("Load Example")
|
||||||
|
with gr.Row():
|
||||||
|
generate_button = gr.Button("Generate", variant='primary', scale=1)
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=6):
|
||||||
|
video_output = gr.Video(label="Output Video", interactive=False)
|
||||||
|
audio_output = gr.Audio(label="Output Audio", interactive=False)
|
||||||
|
send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False)
|
||||||
|
send_to_init_button.click(
|
||||||
|
fn=lambda audio: audio,
|
||||||
|
inputs=[audio_output],
|
||||||
|
outputs=[init_audio_input]
|
||||||
|
)
|
||||||
|
inputs = [
|
||||||
|
prompt,
|
||||||
|
negative_prompt,
|
||||||
|
video_file,
|
||||||
|
video_path,
|
||||||
|
audio_prompt_file,
|
||||||
|
audio_prompt_path,
|
||||||
|
seconds_start_slider,
|
||||||
|
seconds_total_slider,
|
||||||
|
cfg_scale_slider,
|
||||||
|
steps_slider,
|
||||||
|
preview_every_slider,
|
||||||
|
seed_textbox,
|
||||||
|
sampler_type_dropdown,
|
||||||
|
sigma_min_slider,
|
||||||
|
sigma_max_slider,
|
||||||
|
cfg_rescale_slider,
|
||||||
|
init_audio_checkbox,
|
||||||
|
init_audio_input,
|
||||||
|
init_noise_level_slider
|
||||||
|
]
|
||||||
|
generate_button.click(
|
||||||
|
fn=generate_cond,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=[
|
||||||
|
video_output,
|
||||||
|
audio_output
|
||||||
|
],
|
||||||
|
api_name="generate"
|
||||||
|
)
|
||||||
|
ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
|
||||||
|
return demo
|
||||||
|
|
||||||
|
def create_txt2audio_ui(model_config_map):
|
||||||
|
with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui:
|
||||||
|
with gr.Tab("Generation"):
|
||||||
|
create_sampling_ui(model_config_map)
|
||||||
|
return ui
|
||||||
|
|
||||||
|
def toggle_custom_model(selected_model):
|
||||||
|
return gr.Row.update(visible=(selected_model == "Custom Model"))
|
||||||
|
|
||||||
|
def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
|
||||||
|
global model_configurations
|
||||||
|
global device
|
||||||
|
|
||||||
|
try:
|
||||||
|
has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
|
||||||
|
except Exception:
|
||||||
|
has_mps = False
|
||||||
|
|
||||||
|
if has_mps:
|
||||||
|
device = torch.device("mps")
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
print("Using device:", device)
|
||||||
|
|
||||||
|
model_configurations = {
|
||||||
|
"default": {
|
||||||
|
"model_config": "./model/config.json",
|
||||||
|
"ckpt_path": "./model/model.ckpt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ui = create_txt2audio_ui(model_configurations)
|
||||||
|
return ui
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ui = create_ui(
|
||||||
|
model_config_path='./model/config.json',
|
||||||
|
share=True
|
||||||
|
)
|
||||||
|
ui.launch()
|
||||||
1
stable_audio_tools/models/__init__.py
Normal file
1
stable_audio_tools/models/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .factory import create_model_from_config, create_model_from_config_path
|
||||||
1588
stable_audio_tools/models/adp.py
Normal file
1588
stable_audio_tools/models/adp.py
Normal file
File diff suppressed because it is too large
Load diff
794
stable_audio_tools/models/autoencoders.py
Normal file
794
stable_audio_tools/models/autoencoders.py
Normal file
|
|
@ -0,0 +1,794 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
from alias_free_torch import Activation1d
|
||||||
|
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||||
|
from typing import Literal, Dict, Any
|
||||||
|
|
||||||
|
from ..inference.sampling import sample
|
||||||
|
from ..inference.utils import prepare_audio
|
||||||
|
from .blocks import SnakeBeta
|
||||||
|
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||||
|
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||||
|
from .factory import create_pretransform_from_config, create_bottleneck_from_config
|
||||||
|
from .pretransforms import Pretransform
|
||||||
|
|
||||||
|
def checkpoint(function, *args, **kwargs):
|
||||||
|
kwargs.setdefault("use_reentrant", False)
|
||||||
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
|
if activation == "elu":
|
||||||
|
act = nn.ELU()
|
||||||
|
elif activation == "snake":
|
||||||
|
act = SnakeBeta(channels)
|
||||||
|
elif activation == "none":
|
||||||
|
act = nn.Identity()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation {activation}")
|
||||||
|
|
||||||
|
if antialias:
|
||||||
|
act = Activation1d(act)
|
||||||
|
|
||||||
|
return act
|
||||||
|
|
||||||
|
class ResidualUnit(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
padding = (dilation * (7-1)) // 2
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=7, dilation=dilation, padding=padding),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
kernel_size=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x
|
||||||
|
|
||||||
|
#x = checkpoint(self.layers, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
|
||||||
|
return x + res
|
||||||
|
|
||||||
|
class EncoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if use_nearest_upsample:
|
||||||
|
upsample_layer = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||||
|
WNConv1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride,
|
||||||
|
stride=1,
|
||||||
|
bias=False,
|
||||||
|
padding='same')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
upsample_layer,
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=9, use_snake=use_snake),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class OobleckEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1):
|
||||||
|
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OobleckDecoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
out_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False,
|
||||||
|
use_nearest_upsample=False,
|
||||||
|
final_tanh=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1, 0, -1):
|
||||||
|
layers += [DecoderBlock(
|
||||||
|
in_channels=c_mults[i]*channels,
|
||||||
|
out_channels=c_mults[i-1]*channels,
|
||||||
|
stride=strides[i-1],
|
||||||
|
use_snake=use_snake,
|
||||||
|
antialias_activation=antialias_activation,
|
||||||
|
use_nearest_upsample=use_nearest_upsample
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||||
|
nn.Tanh() if final_tanh else nn.Identity()
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DACEncoderWrapper(nn.Module):
|
||||||
|
def __init__(self, in_channels=1, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
from dac.model.dac import Encoder as DACEncoder
|
||||||
|
|
||||||
|
latent_dim = kwargs.pop("latent_dim", None)
|
||||||
|
|
||||||
|
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
|
||||||
|
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
|
||||||
|
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
|
||||||
|
|
||||||
|
if in_channels != 1:
|
||||||
|
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.encoder(x)
|
||||||
|
x = self.proj_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class DACDecoderWrapper(nn.Module):
|
||||||
|
def __init__(self, latent_dim, out_channels=1, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
from dac.model.dac import Decoder as DACDecoder
|
||||||
|
|
||||||
|
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
|
||||||
|
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.decoder(x)
|
||||||
|
|
||||||
|
class AudioAutoencoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
latent_dim,
|
||||||
|
downsampling_ratio,
|
||||||
|
sample_rate,
|
||||||
|
io_channels=2,
|
||||||
|
bottleneck: Bottleneck = None,
|
||||||
|
pretransform: Pretransform = None,
|
||||||
|
in_channels = None,
|
||||||
|
out_channels = None,
|
||||||
|
soft_clip = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.downsampling_ratio = downsampling_ratio
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.in_channels = io_channels
|
||||||
|
self.out_channels = io_channels
|
||||||
|
|
||||||
|
self.min_length = self.downsampling_ratio
|
||||||
|
|
||||||
|
if in_channels is not None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
if out_channels is not None:
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.bottleneck = bottleneck
|
||||||
|
|
||||||
|
self.encoder = encoder
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
self.pretransform = pretransform
|
||||||
|
|
||||||
|
self.soft_clip = soft_clip
|
||||||
|
|
||||||
|
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
||||||
|
|
||||||
|
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||||
|
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
if self.pretransform is not None and not skip_pretransform:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
if iterate_batch:
|
||||||
|
audios = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||||
|
audio = torch.cat(audios, dim=0)
|
||||||
|
else:
|
||||||
|
audio = self.pretransform.encode(audio)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if iterate_batch:
|
||||||
|
audios = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||||
|
audio = torch.cat(audios, dim=0)
|
||||||
|
else:
|
||||||
|
audio = self.pretransform.encode(audio)
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
if iterate_batch:
|
||||||
|
latents = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
latents.append(self.encoder(audio[i:i+1]))
|
||||||
|
latents = torch.cat(latents, dim=0)
|
||||||
|
else:
|
||||||
|
latents = self.encoder(audio)
|
||||||
|
else:
|
||||||
|
latents = audio
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
# TODO: Add iterate batch logic, needs to merge the info dicts
|
||||||
|
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
||||||
|
|
||||||
|
info.update(bottleneck_info)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return latents, info
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def decode(self, latents, iterate_batch=False, **kwargs):
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
if iterate_batch:
|
||||||
|
decoded = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
||||||
|
latents = torch.cat(decoded, dim=0)
|
||||||
|
else:
|
||||||
|
latents = self.bottleneck.decode(latents)
|
||||||
|
|
||||||
|
if iterate_batch:
|
||||||
|
decoded = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decoded.append(self.decoder(latents[i:i+1]))
|
||||||
|
decoded = torch.cat(decoded, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.decoder(latents, **kwargs)
|
||||||
|
|
||||||
|
if self.pretransform is not None:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
if iterate_batch:
|
||||||
|
decodeds = []
|
||||||
|
for i in range(decoded.shape[0]):
|
||||||
|
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||||
|
decoded = torch.cat(decodeds, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if iterate_batch:
|
||||||
|
decodeds = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||||
|
decoded = torch.cat(decodeds, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
|
||||||
|
if self.soft_clip:
|
||||||
|
decoded = torch.tanh(decoded)
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
'''
|
||||||
|
Decode discrete tokens to audio
|
||||||
|
Only works with discrete autoencoders
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
|
||||||
|
|
||||||
|
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_audio_for_encoder(self, audio, in_sr):
|
||||||
|
'''
|
||||||
|
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
|
||||||
|
If the model is mono, stereo audio will be converted to mono.
|
||||||
|
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
|
||||||
|
Audio will be resampled to the model's sample rate.
|
||||||
|
The output will have batch size 1 and be shape (1 x Channels x Length)
|
||||||
|
'''
|
||||||
|
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
|
||||||
|
|
||||||
|
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
|
||||||
|
'''
|
||||||
|
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
|
||||||
|
The audio in that list can be of different lengths and channels.
|
||||||
|
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
|
||||||
|
All audio will be resampled to the model's sample rate.
|
||||||
|
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
|
||||||
|
If the model is mono, all audio will be converted to mono.
|
||||||
|
The output will be a tensor of shape (Batch x Channels x Length)
|
||||||
|
'''
|
||||||
|
batch_size = len(audio_list)
|
||||||
|
if isinstance(in_sr_list, int):
|
||||||
|
in_sr_list = [in_sr_list]*batch_size
|
||||||
|
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
|
||||||
|
new_audio = []
|
||||||
|
max_length = 0
|
||||||
|
# resample & find the max length
|
||||||
|
for i in range(batch_size):
|
||||||
|
audio = audio_list[i]
|
||||||
|
in_sr = in_sr_list[i]
|
||||||
|
if len(audio.shape) == 3 and audio.shape[0] == 1:
|
||||||
|
# batchsize 1 was given by accident. Just squeeze it.
|
||||||
|
audio = audio.squeeze(0)
|
||||||
|
elif len(audio.shape) == 1:
|
||||||
|
# Mono signal, channel dimension is missing, unsqueeze it in
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
|
||||||
|
# Resample audio
|
||||||
|
if in_sr != self.sample_rate:
|
||||||
|
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
new_audio.append(audio)
|
||||||
|
if audio.shape[-1] > max_length:
|
||||||
|
max_length = audio.shape[-1]
|
||||||
|
# Pad every audio to the same length, multiple of model's downsampling ratio
|
||||||
|
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
|
||||||
|
new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
|
||||||
|
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
|
||||||
|
# convert to tensor
|
||||||
|
return torch.stack(new_audio)
|
||||||
|
|
||||||
|
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||||
|
'''
|
||||||
|
Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
|
||||||
|
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
|
||||||
|
Overlap and chunk_size params are both measured in number of latents (not audio samples)
|
||||||
|
# and therefore you likely could use the same values with decode_audio.
|
||||||
|
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||||
|
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||||
|
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
|
||||||
|
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||||
|
Smaller chunk_size uses less memory, but more compute.
|
||||||
|
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||||
|
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||||
|
'''
|
||||||
|
if not chunked:
|
||||||
|
# default behavior. Encode the entire audio in parallel
|
||||||
|
return self.encode(audio, **kwargs)
|
||||||
|
else:
|
||||||
|
# CHUNKED ENCODING
|
||||||
|
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||||
|
samples_per_latent = self.downsampling_ratio
|
||||||
|
total_size = audio.shape[2] # in samples
|
||||||
|
batch_size = audio.shape[0]
|
||||||
|
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||||
|
overlap *= samples_per_latent # converting metric in latents to samples
|
||||||
|
hop_size = chunk_size - overlap
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||||
|
chunk = audio[:,:,i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
if i+chunk_size != total_size:
|
||||||
|
# Final chunk
|
||||||
|
chunk = audio[:,:,-chunk_size:]
|
||||||
|
chunks.append(chunk)
|
||||||
|
chunks = torch.stack(chunks)
|
||||||
|
num_chunks = chunks.shape[0]
|
||||||
|
# Note: y_size might be a different value from the latent length used in diffusion training
|
||||||
|
# because we can encode audio of varying lengths
|
||||||
|
# However, the audio should've been padded to a multiple of samples_per_latent by now.
|
||||||
|
y_size = total_size // samples_per_latent
|
||||||
|
# Create an empty latent, we will populate it with chunks as we encode them
|
||||||
|
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||||
|
for i in range(num_chunks):
|
||||||
|
x_chunk = chunks[i,:]
|
||||||
|
# encode the chunk
|
||||||
|
y_chunk = self.encode(x_chunk)
|
||||||
|
# figure out where to put the audio along the time domain
|
||||||
|
if i == num_chunks-1:
|
||||||
|
# final chunk always goes at the end
|
||||||
|
t_end = y_size
|
||||||
|
t_start = t_end - y_chunk.shape[2]
|
||||||
|
else:
|
||||||
|
t_start = i * hop_size // samples_per_latent
|
||||||
|
t_end = t_start + chunk_size // samples_per_latent
|
||||||
|
# remove the edges of the overlaps
|
||||||
|
ol = overlap//samples_per_latent//2
|
||||||
|
chunk_start = 0
|
||||||
|
chunk_end = y_chunk.shape[2]
|
||||||
|
if i > 0:
|
||||||
|
# no overlap for the start of the first chunk
|
||||||
|
t_start += ol
|
||||||
|
chunk_start += ol
|
||||||
|
if i < num_chunks-1:
|
||||||
|
# no overlap for the end of the last chunk
|
||||||
|
t_end -= ol
|
||||||
|
chunk_end -= ol
|
||||||
|
# paste the chunked audio into our y_final output audio
|
||||||
|
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||||
|
return y_final
|
||||||
|
|
||||||
|
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||||
|
'''
|
||||||
|
Decode latents to audio.
|
||||||
|
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
|
||||||
|
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||||
|
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||||
|
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
|
||||||
|
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||||
|
Smaller chunk_size uses less memory, but more compute.
|
||||||
|
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||||
|
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||||
|
'''
|
||||||
|
if not chunked:
|
||||||
|
# default behavior. Decode the entire latent in parallel
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
else:
|
||||||
|
# chunked decoding
|
||||||
|
hop_size = chunk_size - overlap
|
||||||
|
total_size = latents.shape[2]
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||||
|
chunk = latents[:,:,i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
if i+chunk_size != total_size:
|
||||||
|
# Final chunk
|
||||||
|
chunk = latents[:,:,-chunk_size:]
|
||||||
|
chunks.append(chunk)
|
||||||
|
chunks = torch.stack(chunks)
|
||||||
|
num_chunks = chunks.shape[0]
|
||||||
|
# samples_per_latent is just the downsampling ratio
|
||||||
|
samples_per_latent = self.downsampling_ratio
|
||||||
|
# Create an empty waveform, we will populate it with chunks as decode them
|
||||||
|
y_size = total_size * samples_per_latent
|
||||||
|
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
|
||||||
|
for i in range(num_chunks):
|
||||||
|
x_chunk = chunks[i,:]
|
||||||
|
# decode the chunk
|
||||||
|
y_chunk = self.decode(x_chunk)
|
||||||
|
# figure out where to put the audio along the time domain
|
||||||
|
if i == num_chunks-1:
|
||||||
|
# final chunk always goes at the end
|
||||||
|
t_end = y_size
|
||||||
|
t_start = t_end - y_chunk.shape[2]
|
||||||
|
else:
|
||||||
|
t_start = i * hop_size * samples_per_latent
|
||||||
|
t_end = t_start + chunk_size * samples_per_latent
|
||||||
|
# remove the edges of the overlaps
|
||||||
|
ol = (overlap//2) * samples_per_latent
|
||||||
|
chunk_start = 0
|
||||||
|
chunk_end = y_chunk.shape[2]
|
||||||
|
if i > 0:
|
||||||
|
# no overlap for the start of the first chunk
|
||||||
|
t_start += ol
|
||||||
|
chunk_start += ol
|
||||||
|
if i < num_chunks-1:
|
||||||
|
# no overlap for the end of the last chunk
|
||||||
|
t_end -= ol
|
||||||
|
chunk_end -= ol
|
||||||
|
# paste the chunked audio into our y_final output audio
|
||||||
|
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||||
|
return y_final
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionAutoencoder(AudioAutoencoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
diffusion: ConditionedDiffusionModel,
|
||||||
|
diffusion_downsampling_ratio,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.diffusion = diffusion
|
||||||
|
|
||||||
|
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
# Shrink the initial encoder parameters to avoid saturated latents
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.encoder.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def decode(self, latents, steps=100):
|
||||||
|
|
||||||
|
upsampled_length = latents.shape[2] * self.downsampling_ratio
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
latents = self.bottleneck.decode(latents)
|
||||||
|
|
||||||
|
if self.decoder is not None:
|
||||||
|
latents = self.decode(latents)
|
||||||
|
|
||||||
|
# Upsample latents to match diffusion length
|
||||||
|
if latents.shape[2] != upsampled_length:
|
||||||
|
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
|
||||||
|
|
||||||
|
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
|
||||||
|
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
|
||||||
|
|
||||||
|
if self.pretransform is not None:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
# AE factories
|
||||||
|
|
||||||
|
def create_encoder_from_config(encoder_config: Dict[str, Any]):
|
||||||
|
encoder_type = encoder_config.get("type", None)
|
||||||
|
assert encoder_type is not None, "Encoder type must be specified"
|
||||||
|
|
||||||
|
if encoder_type == "oobleck":
|
||||||
|
encoder = OobleckEncoder(
|
||||||
|
**encoder_config["config"]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif encoder_type == "seanet":
|
||||||
|
from encodec.modules import SEANetEncoder
|
||||||
|
seanet_encoder_config = encoder_config["config"]
|
||||||
|
|
||||||
|
#SEANet encoder expects strides in reverse order
|
||||||
|
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
|
||||||
|
encoder = SEANetEncoder(
|
||||||
|
**seanet_encoder_config
|
||||||
|
)
|
||||||
|
elif encoder_type == "dac":
|
||||||
|
dac_config = encoder_config["config"]
|
||||||
|
|
||||||
|
encoder = DACEncoderWrapper(**dac_config)
|
||||||
|
elif encoder_type == "local_attn":
|
||||||
|
from .local_attention import TransformerEncoder1D
|
||||||
|
|
||||||
|
local_attn_config = encoder_config["config"]
|
||||||
|
|
||||||
|
encoder = TransformerEncoder1D(
|
||||||
|
**local_attn_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown encoder type {encoder_type}")
|
||||||
|
|
||||||
|
requires_grad = encoder_config.get("requires_grad", True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
def create_decoder_from_config(decoder_config: Dict[str, Any]):
|
||||||
|
decoder_type = decoder_config.get("type", None)
|
||||||
|
assert decoder_type is not None, "Decoder type must be specified"
|
||||||
|
|
||||||
|
if decoder_type == "oobleck":
|
||||||
|
decoder = OobleckDecoder(
|
||||||
|
**decoder_config["config"]
|
||||||
|
)
|
||||||
|
elif decoder_type == "seanet":
|
||||||
|
from encodec.modules import SEANetDecoder
|
||||||
|
|
||||||
|
decoder = SEANetDecoder(
|
||||||
|
**decoder_config["config"]
|
||||||
|
)
|
||||||
|
elif decoder_type == "dac":
|
||||||
|
dac_config = decoder_config["config"]
|
||||||
|
|
||||||
|
decoder = DACDecoderWrapper(**dac_config)
|
||||||
|
elif decoder_type == "local_attn":
|
||||||
|
from .local_attention import TransformerDecoder1D
|
||||||
|
|
||||||
|
local_attn_config = decoder_config["config"]
|
||||||
|
|
||||||
|
decoder = TransformerDecoder1D(
|
||||||
|
**local_attn_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown decoder type {decoder_type}")
|
||||||
|
|
||||||
|
requires_grad = decoder_config.get("requires_grad", True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in decoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||||
|
|
||||||
|
ae_config = config["model"]
|
||||||
|
|
||||||
|
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||||
|
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||||
|
|
||||||
|
bottleneck = ae_config.get("bottleneck", None)
|
||||||
|
|
||||||
|
latent_dim = ae_config.get("latent_dim", None)
|
||||||
|
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||||
|
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||||
|
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||||
|
io_channels = ae_config.get("io_channels", None)
|
||||||
|
assert io_channels is not None, "io_channels must be specified in model config"
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||||
|
|
||||||
|
in_channels = ae_config.get("in_channels", None)
|
||||||
|
out_channels = ae_config.get("out_channels", None)
|
||||||
|
|
||||||
|
pretransform = ae_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
if bottleneck is not None:
|
||||||
|
bottleneck = create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||||
|
|
||||||
|
return AudioAutoencoder(
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
io_channels=io_channels,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
downsampling_ratio=downsampling_ratio,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
pretransform=pretransform,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
soft_clip=soft_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_diffAE_from_config(config: Dict[str, Any]):
|
||||||
|
|
||||||
|
diffae_config = config["model"]
|
||||||
|
|
||||||
|
if "encoder" in diffae_config:
|
||||||
|
encoder = create_encoder_from_config(diffae_config["encoder"])
|
||||||
|
else:
|
||||||
|
encoder = None
|
||||||
|
|
||||||
|
if "decoder" in diffae_config:
|
||||||
|
decoder = create_decoder_from_config(diffae_config["decoder"])
|
||||||
|
else:
|
||||||
|
decoder = None
|
||||||
|
|
||||||
|
diffusion_model_type = diffae_config["diffusion"]["type"]
|
||||||
|
|
||||||
|
if diffusion_model_type == "DAU1d":
|
||||||
|
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
elif diffusion_model_type == "adp_1d":
|
||||||
|
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
|
||||||
|
latent_dim = diffae_config.get("latent_dim", None)
|
||||||
|
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||||
|
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
|
||||||
|
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||||
|
io_channels = diffae_config.get("io_channels", None)
|
||||||
|
assert io_channels is not None, "io_channels must be specified in model config"
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||||
|
|
||||||
|
bottleneck = diffae_config.get("bottleneck", None)
|
||||||
|
|
||||||
|
pretransform = diffae_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
if bottleneck is not None:
|
||||||
|
bottleneck = create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
diffusion_downsampling_ratio = None,
|
||||||
|
|
||||||
|
if diffusion_model_type == "DAU1d":
|
||||||
|
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||||
|
elif diffusion_model_type == "adp_1d":
|
||||||
|
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
diffusion_downsampling_ratio = 1
|
||||||
|
|
||||||
|
return DiffusionAutoencoder(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
diffusion=diffusion,
|
||||||
|
io_channels=io_channels,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
downsampling_ratio=downsampling_ratio,
|
||||||
|
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
pretransform=pretransform
|
||||||
|
)
|
||||||
339
stable_audio_tools/models/blocks.py
Normal file
339
stable_audio_tools/models/blocks.py
Normal file
|
|
@ -0,0 +1,339 @@
|
||||||
|
from functools import reduce
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from torch.backends.cuda import sdp_kernel
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from dac.nn.layers import Snake1d
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, main, skip=None):
|
||||||
|
super().__init__()
|
||||||
|
self.main = nn.Sequential(*main)
|
||||||
|
self.skip = skip if skip else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.main(input) + self.skip(input)
|
||||||
|
|
||||||
|
class ResConvBlock(ResidualBlock):
|
||||||
|
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
|
||||||
|
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
|
||||||
|
super().__init__([
|
||||||
|
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||||
|
nn.GroupNorm(1, c_mid),
|
||||||
|
Snake1d(c_mid) if use_snake else nn.GELU(),
|
||||||
|
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||||
|
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
|
||||||
|
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
|
||||||
|
], skip)
|
||||||
|
|
||||||
|
class SelfAttention1d(nn.Module):
|
||||||
|
def __init__(self, c_in, n_head=1, dropout_rate=0.):
|
||||||
|
super().__init__()
|
||||||
|
assert c_in % n_head == 0
|
||||||
|
self.norm = nn.GroupNorm(1, c_in)
|
||||||
|
self.n_head = n_head
|
||||||
|
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
|
||||||
|
self.out_proj = nn.Conv1d(c_in, c_in, 1)
|
||||||
|
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||||
|
|
||||||
|
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
|
||||||
|
|
||||||
|
if not self.use_flash:
|
||||||
|
return
|
||||||
|
|
||||||
|
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
||||||
|
|
||||||
|
if device_properties.major == 8 and device_properties.minor == 0:
|
||||||
|
# Use flash attention for A100 GPUs
|
||||||
|
self.sdp_kernel_config = (True, False, False)
|
||||||
|
else:
|
||||||
|
# Don't use flash attention for other GPUs
|
||||||
|
self.sdp_kernel_config = (False, True, True)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
n, c, s = input.shape
|
||||||
|
qkv = self.qkv_proj(self.norm(input))
|
||||||
|
qkv = qkv.view(
|
||||||
|
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
|
||||||
|
q, k, v = qkv.chunk(3, dim=1)
|
||||||
|
scale = k.shape[3]**-0.25
|
||||||
|
|
||||||
|
if self.use_flash:
|
||||||
|
with sdp_kernel(*self.sdp_kernel_config):
|
||||||
|
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
|
||||||
|
else:
|
||||||
|
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
||||||
|
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
|
||||||
|
|
||||||
|
|
||||||
|
return input + self.dropout(self.out_proj(y))
|
||||||
|
|
||||||
|
class SkipBlock(nn.Module):
|
||||||
|
def __init__(self, *main):
|
||||||
|
super().__init__()
|
||||||
|
self.main = nn.Sequential(*main)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.cat([self.main(input), input], dim=1)
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, std=1.):
|
||||||
|
super().__init__()
|
||||||
|
assert out_features % 2 == 0
|
||||||
|
self.weight = nn.Parameter(torch.randn(
|
||||||
|
[out_features // 2, in_features]) * std)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
f = 2 * math.pi * input @ self.weight.T
|
||||||
|
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||||
|
|
||||||
|
def expand_to_planes(input, shape):
|
||||||
|
return input[..., None].repeat([1, 1, shape[2]])
|
||||||
|
|
||||||
|
_kernels = {
|
||||||
|
'linear':
|
||||||
|
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||||
|
'cubic':
|
||||||
|
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
||||||
|
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||||
|
'lanczos3':
|
||||||
|
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
||||||
|
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
||||||
|
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
||||||
|
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
||||||
|
}
|
||||||
|
|
||||||
|
class Downsample1d(nn.Module):
|
||||||
|
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
kernel_1d = torch.tensor(_kernels[kernel])
|
||||||
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||||
|
self.register_buffer('kernel', kernel_1d)
|
||||||
|
self.channels_last = channels_last
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
|
||||||
|
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||||
|
indices = torch.arange(x.shape[1], device=x.device)
|
||||||
|
weight[indices, indices] = self.kernel.to(weight)
|
||||||
|
x = F.conv1d(x, weight, stride=2)
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1d(nn.Module):
|
||||||
|
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||||
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||||
|
self.register_buffer('kernel', kernel_1d)
|
||||||
|
self.channels_last = channels_last
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||||
|
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||||
|
indices = torch.arange(x.shape[1], device=x.device)
|
||||||
|
weight[indices, indices] = self.kernel.to(weight)
|
||||||
|
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def Downsample1d_2(
|
||||||
|
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
|
||||||
|
) -> nn.Module:
|
||||||
|
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
|
||||||
|
|
||||||
|
return nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=factor * kernel_multiplier + 1,
|
||||||
|
stride=factor,
|
||||||
|
padding=factor * (kernel_multiplier // 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def Upsample1d_2(
|
||||||
|
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
if factor == 1:
|
||||||
|
return nn.Conv1d(
|
||||||
|
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_nearest:
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=factor, mode="nearest"),
|
||||||
|
nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return nn.ConvTranspose1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=factor * 2,
|
||||||
|
stride=factor,
|
||||||
|
padding=factor // 2 + factor % 2,
|
||||||
|
output_padding=factor % 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_init(layer):
|
||||||
|
nn.init.zeros_(layer.weight)
|
||||||
|
if layer.bias is not None:
|
||||||
|
nn.init.zeros_(layer.bias)
|
||||||
|
return layer
|
||||||
|
|
||||||
|
def rms_norm(x, scale, eps):
|
||||||
|
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
||||||
|
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
||||||
|
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
||||||
|
return x * scale.to(x.dtype)
|
||||||
|
|
||||||
|
#rms_norm = torch.compile(rms_norm)
|
||||||
|
|
||||||
|
class AdaRMSNorm(nn.Module):
|
||||||
|
def __init__(self, features, cond_features, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"eps={self.eps},"
|
||||||
|
|
||||||
|
def forward(self, x, cond):
|
||||||
|
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
|
||||||
|
|
||||||
|
def normalize(x, eps=1e-4):
|
||||||
|
dim = list(range(1, x.ndim))
|
||||||
|
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
|
||||||
|
alpha = np.sqrt(n.numel() / x.numel())
|
||||||
|
return x / torch.add(eps, n, alpha=alpha)
|
||||||
|
|
||||||
|
class ForcedWNConv1d(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=1):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.weight.copy_(normalize(self.weight))
|
||||||
|
|
||||||
|
fan_in = self.weight[0].numel()
|
||||||
|
|
||||||
|
w = normalize(self.weight) / math.sqrt(fan_in)
|
||||||
|
|
||||||
|
return F.conv1d(x, w, padding='same')
|
||||||
|
|
||||||
|
# Kernels
|
||||||
|
|
||||||
|
use_compile = True
|
||||||
|
|
||||||
|
def compile(function, *args, **kwargs):
|
||||||
|
if not use_compile:
|
||||||
|
return function
|
||||||
|
try:
|
||||||
|
return torch.compile(function, *args, **kwargs)
|
||||||
|
except RuntimeError:
|
||||||
|
return function
|
||||||
|
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def linear_geglu(x, weight, bias=None):
|
||||||
|
x = x @ weight.mT
|
||||||
|
if bias is not None:
|
||||||
|
x = x + bias
|
||||||
|
x, gate = x.chunk(2, dim=-1)
|
||||||
|
return x * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def rms_norm(x, scale, eps):
|
||||||
|
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
||||||
|
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
||||||
|
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
||||||
|
return x * scale.to(x.dtype)
|
||||||
|
|
||||||
|
# Layers
|
||||||
|
|
||||||
|
class LinearGEGLU(nn.Linear):
|
||||||
|
def __init__(self, in_features, out_features, bias=True):
|
||||||
|
super().__init__(in_features, out_features * 2, bias=bias)
|
||||||
|
self.out_features = out_features
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return linear_geglu(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, shape, fix_scale = False, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
if fix_scale:
|
||||||
|
self.register_buffer("scale", torch.ones(shape))
|
||||||
|
else:
|
||||||
|
self.scale = nn.Parameter(torch.ones(shape))
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return rms_norm(x, self.scale, self.eps)
|
||||||
|
|
||||||
|
def snake_beta(x, alpha, beta):
|
||||||
|
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# snake_beta = torch.compile(snake_beta)
|
||||||
|
# except RuntimeError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||||
|
# License available in LICENSES/LICENSE_NVIDIA.txt
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
|
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # linear scale alphas initialized to ones
|
||||||
|
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = snake_beta(x, alpha, beta)
|
||||||
|
|
||||||
|
return x
|
||||||
355
stable_audio_tools/models/bottleneck.py
Normal file
355
stable_audio_tools/models/bottleneck.py
Normal file
|
|
@ -0,0 +1,355 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||||
|
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
def __init__(self, is_discrete: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_discrete = is_discrete
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class DiscreteBottleneck(Bottleneck):
|
||||||
|
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
||||||
|
super().__init__(is_discrete=True)
|
||||||
|
|
||||||
|
self.num_quantizers = num_quantizers
|
||||||
|
self.codebook_size = codebook_size
|
||||||
|
self.tokens_id = tokens_id
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class TanhBottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def vae_sample(mean, scale):
|
||||||
|
stdev = nn.functional.softplus(scale) + 1e-4
|
||||||
|
var = stdev * stdev
|
||||||
|
logvar = torch.log(var)
|
||||||
|
latents = torch.randn_like(mean) * stdev + mean
|
||||||
|
|
||||||
|
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||||
|
|
||||||
|
return latents, kl
|
||||||
|
|
||||||
|
class VAEBottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
mean, scale = x.chunk(2, dim=1)
|
||||||
|
|
||||||
|
x, kl = vae_sample(mean, scale)
|
||||||
|
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def compute_mean_kernel(x, y):
|
||||||
|
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
||||||
|
return torch.exp(-kernel_input).mean()
|
||||||
|
|
||||||
|
def compute_mmd(latents):
|
||||||
|
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
||||||
|
noise = torch.randn_like(latents_reshaped)
|
||||||
|
|
||||||
|
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
|
||||||
|
noise_kernel = compute_mean_kernel(noise, noise)
|
||||||
|
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
|
||||||
|
|
||||||
|
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
|
||||||
|
return mmd.mean()
|
||||||
|
|
||||||
|
class WassersteinBottleneck(Bottleneck):
|
||||||
|
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
self.bypass_mmd = bypass_mmd
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
if self.training and return_info:
|
||||||
|
if self.bypass_mmd:
|
||||||
|
mmd = torch.tensor(0.0)
|
||||||
|
else:
|
||||||
|
mmd = compute_mmd(x)
|
||||||
|
|
||||||
|
info["mmd"] = mmd
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class L2Bottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = F.normalize(x, dim=1)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return F.normalize(x, dim=1)
|
||||||
|
|
||||||
|
class RVQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||||
|
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices, loss = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
info["quantizer_loss"] = loss.mean()
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class RVQVAEBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||||
|
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x, kl = vae_sample(*x.chunk(2, dim=1))
|
||||||
|
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices, loss = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
info["quantizer_loss"] = loss.mean()
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class DACRVQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||||
|
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
info["pre_quantizer"] = x
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
return x, info if return_info else x
|
||||||
|
|
||||||
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"z": z,
|
||||||
|
"codes": codes,
|
||||||
|
"latents": latents,
|
||||||
|
"vq/commitment_loss": commitment_loss,
|
||||||
|
"vq/codebook_loss": codebook_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
output["vq/commitment_loss"] /= self.num_quantizers
|
||||||
|
output["vq/codebook_loss"] /= self.num_quantizers
|
||||||
|
|
||||||
|
info.update(output)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output["z"], info
|
||||||
|
|
||||||
|
return output["z"]
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
x = self.quantizer(x)[0]
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents, _, _ = self.quantizer.from_codes(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class DACRVQVAEBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||||
|
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, n_quantizers: int = None):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
mean, scale = x.chunk(2, dim=1)
|
||||||
|
|
||||||
|
x, kl = vae_sample(mean, scale)
|
||||||
|
|
||||||
|
info["pre_quantizer"] = x
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
return x, info if return_info else x
|
||||||
|
|
||||||
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"z": z,
|
||||||
|
"codes": codes,
|
||||||
|
"latents": latents,
|
||||||
|
"vq/commitment_loss": commitment_loss,
|
||||||
|
"vq/codebook_loss": codebook_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
output["vq/commitment_loss"] /= self.num_quantizers
|
||||||
|
output["vq/codebook_loss"] /= self.num_quantizers
|
||||||
|
|
||||||
|
info.update(output)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output["z"], info
|
||||||
|
|
||||||
|
return output["z"]
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
x = self.quantizer(x)[0]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents, _, _ = self.quantizer.from_codes(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class FSQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, noise_augment_dim=0, **kwargs):
|
||||||
|
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
|
||||||
|
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
|
||||||
|
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
x = x.to(orig_dtype)
|
||||||
|
|
||||||
|
# Reorder indices to match the expected format
|
||||||
|
indices = rearrange(indices, "b n q -> b q n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
latents = self.quantizer.indices_to_codes(tokens)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
545
stable_audio_tools/models/codebook_patterns.py
Normal file
545
stable_audio_tools/models/codebook_patterns.py
Normal file
|
|
@ -0,0 +1,545 @@
|
||||||
|
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
|
||||||
|
# License available in LICENSES/LICENSE_META.txt
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
import logging
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import torch
|
||||||
|
|
||||||
|
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
||||||
|
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Pattern:
|
||||||
|
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
||||||
|
|
||||||
|
The codebook pattern consists in a layout, defining for each sequence step
|
||||||
|
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
||||||
|
The first item of the pattern is always an empty list in order to properly insert a special token
|
||||||
|
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
|
||||||
|
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
||||||
|
|
||||||
|
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
||||||
|
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
||||||
|
to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
|
||||||
|
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
||||||
|
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
||||||
|
is returned along with a mask indicating valid tokens.
|
||||||
|
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
||||||
|
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
||||||
|
to fill and specify invalid positions if needed.
|
||||||
|
See the dedicated methods for more details.
|
||||||
|
"""
|
||||||
|
# Pattern layout, for each sequence step, we have a list of coordinates
|
||||||
|
# corresponding to the original codebook timestep and position.
|
||||||
|
# The first list is always an empty list in order to properly insert
|
||||||
|
# a special token to start with.
|
||||||
|
layout: PatternLayout
|
||||||
|
timesteps: int
|
||||||
|
n_q: int
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert len(self.layout) > 0
|
||||||
|
self._validate_layout()
|
||||||
|
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
||||||
|
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
||||||
|
logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
|
||||||
|
|
||||||
|
def _validate_layout(self):
|
||||||
|
"""Runs checks on the layout to ensure a valid pattern is defined.
|
||||||
|
A pattern is considered invalid if:
|
||||||
|
- Multiple timesteps for a same codebook are defined in the same sequence step
|
||||||
|
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
||||||
|
(this would mean that we have future timesteps before past timesteps).
|
||||||
|
"""
|
||||||
|
q_timesteps = {q: 0 for q in range(self.n_q)}
|
||||||
|
for s, seq_coords in enumerate(self.layout):
|
||||||
|
if len(seq_coords) > 0:
|
||||||
|
qs = set()
|
||||||
|
for coord in seq_coords:
|
||||||
|
qs.add(coord.q)
|
||||||
|
last_q_timestep = q_timesteps[coord.q]
|
||||||
|
assert coord.t >= last_q_timestep, \
|
||||||
|
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
||||||
|
q_timesteps[coord.q] = coord.t
|
||||||
|
# each sequence step contains at max 1 coordinate per codebook
|
||||||
|
assert len(qs) == len(seq_coords), \
|
||||||
|
f"Multiple entries for a same codebook are found at step {s}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_sequence_steps(self):
|
||||||
|
return len(self.layout) - 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_delay(self):
|
||||||
|
max_t_in_seq_coords = 0
|
||||||
|
for seq_coords in self.layout[1:]:
|
||||||
|
for coords in seq_coords:
|
||||||
|
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
||||||
|
return max_t_in_seq_coords - self.timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid_layout(self):
|
||||||
|
valid_step = len(self.layout) - self.max_delay
|
||||||
|
return self.layout[:valid_step]
|
||||||
|
|
||||||
|
def starts_with_special_token(self):
|
||||||
|
return self.layout[0] == []
|
||||||
|
|
||||||
|
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
||||||
|
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
||||||
|
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
||||||
|
and the actual codebook coordinates.
|
||||||
|
"""
|
||||||
|
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
|
||||||
|
if q is not None:
|
||||||
|
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
|
||||||
|
coords = []
|
||||||
|
for s, seq_codes in enumerate(self.layout):
|
||||||
|
for code in seq_codes:
|
||||||
|
if code.t == t and (q is None or code.q == q):
|
||||||
|
coords.append((s, code))
|
||||||
|
return coords
|
||||||
|
|
||||||
|
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
|
||||||
|
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
||||||
|
|
||||||
|
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
|
||||||
|
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
||||||
|
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
||||||
|
|
||||||
|
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
|
||||||
|
device: tp.Union[torch.device, str] = 'cpu'):
|
||||||
|
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps (int): Maximum number of timesteps steps to consider.
|
||||||
|
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
||||||
|
device (torch.device or str): Device for created tensors.
|
||||||
|
Returns:
|
||||||
|
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
||||||
|
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
||||||
|
"""
|
||||||
|
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
||||||
|
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
|
||||||
|
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
||||||
|
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
||||||
|
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
||||||
|
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
||||||
|
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
|
||||||
|
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
|
||||||
|
# fill indexes with last sequence step value that will correspond to our special token
|
||||||
|
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
|
||||||
|
# which will correspond to the index: n_q * timesteps
|
||||||
|
indexes[:] = n_q * timesteps
|
||||||
|
# iterate over the pattern and fill scattered indexes and mask
|
||||||
|
for s, sequence_coords in enumerate(ref_layout):
|
||||||
|
for coords in sequence_coords:
|
||||||
|
if coords.t < timesteps:
|
||||||
|
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
||||||
|
mask[coords.q, s] = 1
|
||||||
|
indexes = torch.from_numpy(indexes).to(device)
|
||||||
|
mask = torch.from_numpy(mask).to(device)
|
||||||
|
return indexes, mask
|
||||||
|
|
||||||
|
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
||||||
|
"""Build sequence corresponding to the pattern from the input tensor z.
|
||||||
|
The sequence is built using up to sequence_steps if specified, and non-pattern
|
||||||
|
coordinates are filled with the special token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
||||||
|
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
||||||
|
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
||||||
|
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
||||||
|
Returns:
|
||||||
|
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
||||||
|
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
||||||
|
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
||||||
|
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
||||||
|
"""
|
||||||
|
B, K, T = z.shape
|
||||||
|
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
||||||
|
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
||||||
|
)
|
||||||
|
z = z.view(B, -1)
|
||||||
|
# we append the special token as the last index of our flattened z tensor
|
||||||
|
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
||||||
|
values = z[:, indexes.view(-1)]
|
||||||
|
values = values.view(B, K, indexes.shape[-1])
|
||||||
|
return values, indexes, mask
|
||||||
|
|
||||||
|
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
|
||||||
|
keep_only_valid_steps: bool = False,
|
||||||
|
is_model_output: bool = False,
|
||||||
|
device: tp.Union[torch.device, str] = 'cpu'):
|
||||||
|
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
||||||
|
from interleaving pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_steps (int): Sequence steps.
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
||||||
|
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
||||||
|
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
||||||
|
device (torch.device or str): Device for created tensors.
|
||||||
|
Returns:
|
||||||
|
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
|
||||||
|
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
||||||
|
"""
|
||||||
|
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
||||||
|
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
|
||||||
|
timesteps = self.timesteps
|
||||||
|
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
||||||
|
assert sequence_steps <= len(ref_layout), \
|
||||||
|
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
||||||
|
|
||||||
|
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
||||||
|
if is_model_output and self.starts_with_special_token():
|
||||||
|
ref_layout = ref_layout[1:]
|
||||||
|
|
||||||
|
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
||||||
|
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
|
||||||
|
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
|
||||||
|
# fill indexes with last sequence step value that will correspond to our special token
|
||||||
|
indexes[:] = n_q * sequence_steps
|
||||||
|
for s, sequence_codes in enumerate(ref_layout):
|
||||||
|
if s < sequence_steps:
|
||||||
|
for code in sequence_codes:
|
||||||
|
if code.t < timesteps:
|
||||||
|
indexes[code.q, code.t] = s + code.q * sequence_steps
|
||||||
|
mask[code.q, code.t] = 1
|
||||||
|
indexes = torch.from_numpy(indexes).to(device)
|
||||||
|
mask = torch.from_numpy(mask).to(device)
|
||||||
|
return indexes, mask
|
||||||
|
|
||||||
|
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
||||||
|
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
||||||
|
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
||||||
|
are filled with the special token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
||||||
|
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
||||||
|
Returns:
|
||||||
|
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
||||||
|
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
||||||
|
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
||||||
|
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
||||||
|
"""
|
||||||
|
B, K, S = s.shape
|
||||||
|
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
||||||
|
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
||||||
|
)
|
||||||
|
s = s.view(B, -1)
|
||||||
|
# we append the special token as the last index of our flattened z tensor
|
||||||
|
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
||||||
|
values = s[:, indexes.view(-1)]
|
||||||
|
values = values.view(B, K, indexes.shape[-1])
|
||||||
|
return values, indexes, mask
|
||||||
|
|
||||||
|
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
||||||
|
"""Revert model logits obtained on a sequence built from the pattern
|
||||||
|
back to a tensor matching the original sequence.
|
||||||
|
|
||||||
|
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
||||||
|
1. It is designed to work with the extra cardinality dimension
|
||||||
|
2. We return the logits for the first sequence item that matches the special_token and
|
||||||
|
which matching target in the original sequence is the first item of the sequence,
|
||||||
|
while we skip the last logits as there is no matching target
|
||||||
|
"""
|
||||||
|
B, card, K, S = logits.shape
|
||||||
|
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
||||||
|
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
||||||
|
)
|
||||||
|
logits = logits.reshape(B, card, -1)
|
||||||
|
# we append the special token as the last index of our flattened z tensor
|
||||||
|
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
||||||
|
values = logits[:, :, indexes.view(-1)]
|
||||||
|
values = values.view(B, card, K, indexes.shape[-1])
|
||||||
|
return values, indexes, mask
|
||||||
|
|
||||||
|
|
||||||
|
class CodebooksPatternProvider(ABC):
|
||||||
|
"""Abstraction around providing pattern for interleaving codebooks.
|
||||||
|
|
||||||
|
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
||||||
|
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
||||||
|
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
||||||
|
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
||||||
|
can be used to construct a new sequence from the original codes respecting the specified
|
||||||
|
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
||||||
|
being a tuple with the original timestep and codebook to build the new sequence.
|
||||||
|
Note that all patterns must start with an empty list that is then used to insert a first
|
||||||
|
sequence step of special tokens in the newly generated sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): number of codebooks.
|
||||||
|
cached (bool): if True, patterns for a given length are cached. In general
|
||||||
|
that should be true for efficiency reason to avoid synchronization points.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_q: int, cached: bool = True):
|
||||||
|
assert n_q > 0
|
||||||
|
self.n_q = n_q
|
||||||
|
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pattern(self, timesteps: int) -> Pattern:
|
||||||
|
"""Builds pattern with specific interleaving between codebooks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps (int): Total number of timesteps.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DelayedPatternProvider(CodebooksPatternProvider):
|
||||||
|
"""Provider for delayed pattern across delayed codebooks.
|
||||||
|
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
||||||
|
from different timesteps.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
||||||
|
[[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4]]
|
||||||
|
The resulting sequence obtained from the returned pattern is:
|
||||||
|
[[S, 1, 2, 3, 4],
|
||||||
|
[S, S, 1, 2, 3],
|
||||||
|
[S, S, S, 1, 2]]
|
||||||
|
(with S being a special token)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
delays (list of int, optional): Delay for each of the codebooks.
|
||||||
|
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
||||||
|
flatten_first (int): Flatten the first N timesteps.
|
||||||
|
empty_initial (int): Prepend with N empty list of coordinates.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
|
||||||
|
flatten_first: int = 0, empty_initial: int = 0):
|
||||||
|
super().__init__(n_q)
|
||||||
|
if delays is None:
|
||||||
|
delays = list(range(n_q))
|
||||||
|
self.delays = delays
|
||||||
|
self.flatten_first = flatten_first
|
||||||
|
self.empty_initial = empty_initial
|
||||||
|
assert len(self.delays) == self.n_q
|
||||||
|
assert sorted(self.delays) == self.delays
|
||||||
|
|
||||||
|
def get_pattern(self, timesteps: int) -> Pattern:
|
||||||
|
omit_special_token = self.empty_initial < 0
|
||||||
|
out: PatternLayout = [] if omit_special_token else [[]]
|
||||||
|
max_delay = max(self.delays)
|
||||||
|
if self.empty_initial:
|
||||||
|
out += [[] for _ in range(self.empty_initial)]
|
||||||
|
if self.flatten_first:
|
||||||
|
for t in range(min(timesteps, self.flatten_first)):
|
||||||
|
for q in range(self.n_q):
|
||||||
|
out.append([LayoutCoord(t, q)])
|
||||||
|
for t in range(self.flatten_first, timesteps + max_delay):
|
||||||
|
v = []
|
||||||
|
for q, delay in enumerate(self.delays):
|
||||||
|
t_for_q = t - delay
|
||||||
|
if t_for_q >= self.flatten_first:
|
||||||
|
v.append(LayoutCoord(t_for_q, q))
|
||||||
|
out.append(v)
|
||||||
|
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelPatternProvider(DelayedPatternProvider):
|
||||||
|
"""Provider for parallel pattern across codebooks.
|
||||||
|
This pattern provider is a special case of the delayed pattern with actually no delay,
|
||||||
|
hence delays=repeat(0, n_q).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
empty_initial (int): Prepend with N empty list of coordinates.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_q: int, empty_initial: int = 0):
|
||||||
|
super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
|
||||||
|
|
||||||
|
|
||||||
|
class UnrolledPatternProvider(CodebooksPatternProvider):
|
||||||
|
"""Provider for unrolling codebooks pattern.
|
||||||
|
This pattern provider enables to represent the codebook flattened completely or only to some extend
|
||||||
|
while also specifying a given delay between the flattened codebooks representation, allowing to
|
||||||
|
unroll the codebooks in the sequence.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
1. Flattening of the codebooks.
|
||||||
|
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
|
||||||
|
taking n_q = 3 and timesteps = 4:
|
||||||
|
[[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4]]
|
||||||
|
will result into:
|
||||||
|
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
|
||||||
|
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
||||||
|
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
||||||
|
2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
|
||||||
|
for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
|
||||||
|
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
|
||||||
|
[[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4]]
|
||||||
|
will result into:
|
||||||
|
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
||||||
|
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
||||||
|
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
||||||
|
3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
|
||||||
|
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
|
||||||
|
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
|
||||||
|
and delays = [0, 3, 3]:
|
||||||
|
[[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[1, 2, 3, 4]]
|
||||||
|
will result into:
|
||||||
|
[[S, S, S, 1, S, 2, S, 3, S, 4],
|
||||||
|
[S, S, S, 1, S, 2, S, 3, S, 4],
|
||||||
|
[1, 2, 3, S, 4, S, 5, S, 6, S]]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
|
||||||
|
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
|
||||||
|
have n_q extra steps for each timestep.
|
||||||
|
delays (list of int, optional): Delay for each of the codebooks. If not defined,
|
||||||
|
no delay is added and therefore will default to [0] * ``n_q``.
|
||||||
|
Note that two codebooks that will be flattened to the same inner step
|
||||||
|
should have the same delay, otherwise the pattern is considered as invalid.
|
||||||
|
"""
|
||||||
|
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
|
||||||
|
|
||||||
|
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
|
||||||
|
delays: tp.Optional[tp.List[int]] = None):
|
||||||
|
super().__init__(n_q)
|
||||||
|
if flattening is None:
|
||||||
|
flattening = list(range(n_q))
|
||||||
|
if delays is None:
|
||||||
|
delays = [0] * n_q
|
||||||
|
assert len(flattening) == n_q
|
||||||
|
assert len(delays) == n_q
|
||||||
|
assert sorted(flattening) == flattening
|
||||||
|
assert sorted(delays) == delays
|
||||||
|
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
|
||||||
|
self.max_delay = max(delays)
|
||||||
|
|
||||||
|
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
|
||||||
|
"""Build a flattened codebooks representation as a dictionary of inner step
|
||||||
|
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
|
||||||
|
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
|
||||||
|
"""
|
||||||
|
flattened_codebooks: dict = {}
|
||||||
|
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
|
||||||
|
if inner_step not in flattened_codebooks:
|
||||||
|
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
|
||||||
|
else:
|
||||||
|
flat_codebook = flattened_codebooks[inner_step]
|
||||||
|
assert flat_codebook.delay == delay, (
|
||||||
|
"Delay and flattening between codebooks is inconsistent: ",
|
||||||
|
"two codebooks flattened to the same position should have the same delay."
|
||||||
|
)
|
||||||
|
flat_codebook.codebooks.append(q)
|
||||||
|
flattened_codebooks[inner_step] = flat_codebook
|
||||||
|
return flattened_codebooks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _num_inner_steps(self):
|
||||||
|
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
|
||||||
|
"""
|
||||||
|
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
|
||||||
|
|
||||||
|
def num_virtual_steps(self, timesteps: int) -> int:
|
||||||
|
return timesteps * self._num_inner_steps + 1
|
||||||
|
|
||||||
|
def get_pattern(self, timesteps: int) -> Pattern:
|
||||||
|
"""Builds pattern for delay across codebooks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps (int): Total number of timesteps.
|
||||||
|
"""
|
||||||
|
# the PatternLayout is built as a tuple of sequence position and list of coordinates
|
||||||
|
# so that it can be reordered properly given the required delay between codebooks of given timesteps
|
||||||
|
indexed_out: list = [(-1, [])]
|
||||||
|
max_timesteps = timesteps + self.max_delay
|
||||||
|
for t in range(max_timesteps):
|
||||||
|
# for each timestep, we unroll the flattened codebooks,
|
||||||
|
# emitting the sequence step with the corresponding delay
|
||||||
|
for step in range(self._num_inner_steps):
|
||||||
|
if step in self._flattened_codebooks:
|
||||||
|
# we have codebooks at this virtual step to emit
|
||||||
|
step_codebooks = self._flattened_codebooks[step]
|
||||||
|
t_for_q = t + step_codebooks.delay
|
||||||
|
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
|
||||||
|
if t_for_q < max_timesteps and t < max_timesteps:
|
||||||
|
indexed_out.append((t_for_q, coords))
|
||||||
|
else:
|
||||||
|
# there is no codebook in this virtual step so we emit an empty list
|
||||||
|
indexed_out.append((t, []))
|
||||||
|
out = [coords for _, coords in sorted(indexed_out)]
|
||||||
|
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
||||||
|
|
||||||
|
|
||||||
|
class CoarseFirstPattern(CodebooksPatternProvider):
|
||||||
|
"""First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
|
||||||
|
potentially with delays.
|
||||||
|
|
||||||
|
..Warning:: You must always generate the full training duration at test time, for instance,
|
||||||
|
30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
|
||||||
|
location. This is due to the non causality of the remaining codebooks with respect to
|
||||||
|
the first ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
delays (list of int, optional): Delay for each of the codebooks.
|
||||||
|
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
|
||||||
|
super().__init__(n_q)
|
||||||
|
if delays is None:
|
||||||
|
delays = [0] * (n_q - 1)
|
||||||
|
self.delays = delays
|
||||||
|
assert len(self.delays) == self.n_q - 1
|
||||||
|
assert sorted(self.delays) == self.delays
|
||||||
|
|
||||||
|
def get_pattern(self, timesteps: int) -> Pattern:
|
||||||
|
out: PatternLayout = [[]]
|
||||||
|
for t in range(timesteps):
|
||||||
|
out.append([LayoutCoord(t, 0)])
|
||||||
|
max_delay = max(self.delays)
|
||||||
|
for t in range(timesteps + max_delay):
|
||||||
|
v = []
|
||||||
|
for q, delay in enumerate(self.delays):
|
||||||
|
t_for_q = t - delay
|
||||||
|
if t_for_q >= 0:
|
||||||
|
v.append(LayoutCoord(t_for_q, q + 1))
|
||||||
|
out.append(v)
|
||||||
|
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
||||||
|
|
||||||
|
|
||||||
|
class MusicLMPattern(CodebooksPatternProvider):
|
||||||
|
"""Almost MusicLM style pattern. This is equivalent to full flattening
|
||||||
|
but in a different order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_q (int): Number of codebooks.
|
||||||
|
group_by (int): Number of codebooks to group together.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_q: int, group_by: int = 2):
|
||||||
|
super().__init__(n_q)
|
||||||
|
self.group_by = group_by
|
||||||
|
|
||||||
|
def get_pattern(self, timesteps: int) -> Pattern:
|
||||||
|
out: PatternLayout = [[]]
|
||||||
|
for offset in range(0, self.n_q, self.group_by):
|
||||||
|
for t in range(timesteps):
|
||||||
|
for q in range(offset, offset + self.group_by):
|
||||||
|
out.append([LayoutCoord(t, q)])
|
||||||
|
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
||||||
710
stable_audio_tools/models/conditioners.py
Normal file
710
stable_audio_tools/models/conditioners.py
Normal file
|
|
@ -0,0 +1,710 @@
|
||||||
|
#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import logging, warnings
|
||||||
|
import string
|
||||||
|
import typing as tp
|
||||||
|
import gc
|
||||||
|
|
||||||
|
from .adp import NumberEmbedder
|
||||||
|
from ..inference.utils import set_audio_channels
|
||||||
|
from .factory import create_pretransform_from_config
|
||||||
|
from .pretransforms import Pretransform
|
||||||
|
from .utils import load_ckpt_state_dict
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoProcessor, CLIPVisionModelWithProjection
|
||||||
|
import einops
|
||||||
|
from .temptransformer import SA_Transformer
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch
|
||||||
|
import einops
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
|
||||||
|
class Conditioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
project_out: bool = False
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: tp.Any) -> tp.Any:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class IntConditioner(Conditioner):
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
min_val: int=0,
|
||||||
|
max_val: int=512
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim)
|
||||||
|
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
|
||||||
|
|
||||||
|
def forward(self, ints: tp.List[int], device=None) -> tp.Any:
|
||||||
|
|
||||||
|
#self.int_embedder.to(device)
|
||||||
|
|
||||||
|
ints = torch.tensor(ints).to(device)
|
||||||
|
ints = ints.clamp(self.min_val, self.max_val)
|
||||||
|
|
||||||
|
int_embeds = self.int_embedder(ints).unsqueeze(1)
|
||||||
|
|
||||||
|
return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
|
||||||
|
|
||||||
|
class NumberConditioner(Conditioner):
|
||||||
|
'''
|
||||||
|
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
min_val: float=0,
|
||||||
|
max_val: float=1
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim)
|
||||||
|
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
self.embedder = NumberEmbedder(features=output_dim)
|
||||||
|
|
||||||
|
def forward(self, floats: tp.List[float], device=None) -> tp.Any:
|
||||||
|
|
||||||
|
# Cast the inputs to floats
|
||||||
|
floats = [float(x) for x in floats]
|
||||||
|
|
||||||
|
floats = torch.tensor(floats).to(device)
|
||||||
|
|
||||||
|
floats = floats.clamp(self.min_val, self.max_val)
|
||||||
|
|
||||||
|
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
|
||||||
|
|
||||||
|
# Cast floats to same type as embedder
|
||||||
|
embedder_dtype = next(self.embedder.parameters()).dtype
|
||||||
|
normalized_floats = normalized_floats.to(embedder_dtype)
|
||||||
|
|
||||||
|
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
|
||||||
|
|
||||||
|
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
||||||
|
|
||||||
|
class CLAPTextConditioner(Conditioner):
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
clap_ckpt_path,
|
||||||
|
use_text_features = False,
|
||||||
|
feature_layer_ix: int = -1,
|
||||||
|
audio_model_type="HTSAT-base",
|
||||||
|
enable_fusion=True,
|
||||||
|
project_out: bool = False,
|
||||||
|
finetune: bool = False):
|
||||||
|
super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
self.use_text_features = use_text_features
|
||||||
|
self.feature_layer_ix = feature_layer_ix
|
||||||
|
self.finetune = finetune
|
||||||
|
|
||||||
|
# Suppress logging from transformers
|
||||||
|
previous_level = logging.root.manager.disable
|
||||||
|
logging.disable(logging.ERROR)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
try:
|
||||||
|
import laion_clap
|
||||||
|
from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
|
||||||
|
|
||||||
|
model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
|
||||||
|
|
||||||
|
if self.finetune:
|
||||||
|
self.model = model
|
||||||
|
else:
|
||||||
|
self.__dict__["model"] = model
|
||||||
|
|
||||||
|
state_dict = clap_load_state_dict(clap_ckpt_path)
|
||||||
|
self.model.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if self.finetune:
|
||||||
|
self.model.model.text_branch.requires_grad_(True)
|
||||||
|
self.model.model.text_branch.train()
|
||||||
|
else:
|
||||||
|
self.model.model.text_branch.requires_grad_(False)
|
||||||
|
self.model.model.text_branch.eval()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
logging.disable(previous_level)
|
||||||
|
|
||||||
|
del self.model.model.audio_branch
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
|
||||||
|
prompt_tokens = self.model.tokenizer(prompts)
|
||||||
|
attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
|
||||||
|
prompt_features = self.model.model.text_branch(
|
||||||
|
input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True
|
||||||
|
)["hidden_states"][layer_ix]
|
||||||
|
|
||||||
|
return prompt_features, attention_mask
|
||||||
|
|
||||||
|
def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
|
if self.use_text_features:
|
||||||
|
if len(texts) == 1:
|
||||||
|
text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
|
||||||
|
text_features = text_features[:1, ...]
|
||||||
|
text_attention_mask = text_attention_mask[:1, ...]
|
||||||
|
else:
|
||||||
|
text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
|
||||||
|
return [self.proj_out(text_features), text_attention_mask]
|
||||||
|
|
||||||
|
# Fix for CLAP bug when only one text is passed
|
||||||
|
if len(texts) == 1:
|
||||||
|
text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
|
||||||
|
else:
|
||||||
|
text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
|
||||||
|
|
||||||
|
text_embedding = text_embedding.unsqueeze(1).to(device)
|
||||||
|
|
||||||
|
return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
|
||||||
|
|
||||||
|
class CLAPAudioConditioner(Conditioner):
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
clap_ckpt_path,
|
||||||
|
audio_model_type="HTSAT-base",
|
||||||
|
enable_fusion=True,
|
||||||
|
project_out: bool = False):
|
||||||
|
super().__init__(512, output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Suppress logging from transformers
|
||||||
|
previous_level = logging.root.manager.disable
|
||||||
|
logging.disable(logging.ERROR)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
try:
|
||||||
|
import laion_clap
|
||||||
|
from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
|
||||||
|
|
||||||
|
model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
|
||||||
|
|
||||||
|
if self.finetune:
|
||||||
|
self.model = model
|
||||||
|
else:
|
||||||
|
self.__dict__["model"] = model
|
||||||
|
|
||||||
|
state_dict = clap_load_state_dict(clap_ckpt_path)
|
||||||
|
self.model.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if self.finetune:
|
||||||
|
self.model.model.audio_branch.requires_grad_(True)
|
||||||
|
self.model.model.audio_branch.train()
|
||||||
|
else:
|
||||||
|
self.model.model.audio_branch.requires_grad_(False)
|
||||||
|
self.model.model.audio_branch.eval()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
logging.disable(previous_level)
|
||||||
|
|
||||||
|
del self.model.model.text_branch
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
|
||||||
|
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
|
if isinstance(audios, list) or isinstance(audios, tuple):
|
||||||
|
audios = torch.cat(audios, dim=0)
|
||||||
|
|
||||||
|
# Convert to mono
|
||||||
|
mono_audios = audios.mean(dim=1)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
|
||||||
|
|
||||||
|
audio_embedding = audio_embedding.unsqueeze(1).to(device)
|
||||||
|
|
||||||
|
return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPConditioner(Conditioner):
|
||||||
|
CLIP_MODELS = ["clip-vit-base-patch32"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_dim: int,
|
||||||
|
clip_model_name: str = "clip-vit-base-patch32",
|
||||||
|
video_fps: int = 5,
|
||||||
|
out_features: str = 128,
|
||||||
|
enable_grad: bool = False,
|
||||||
|
in_features: int = 5000,
|
||||||
|
project_out: bool = False,
|
||||||
|
):
|
||||||
|
assert clip_model_name in self.CLIP_MODELS, f"Unknown clip model name: {clip_model_name}"
|
||||||
|
super().__init__(dim = 768, output_dim=output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
sa_depth=4
|
||||||
|
num_heads=16
|
||||||
|
dim_head=64
|
||||||
|
hidden_scale=4
|
||||||
|
duration = 10
|
||||||
|
|
||||||
|
self.clip_model_name=clip_model_name
|
||||||
|
|
||||||
|
if self.clip_model_name=='clip-vit-base-patch32':
|
||||||
|
out_features = 128
|
||||||
|
temporal_dim=768
|
||||||
|
|
||||||
|
self.empty_visual_feat = nn.Parameter(torch.zeros(1, out_features, temporal_dim), requires_grad=True)
|
||||||
|
nn.init.constant_(self.empty_visual_feat, 0)
|
||||||
|
|
||||||
|
in_features = 50*video_fps*duration
|
||||||
|
|
||||||
|
self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-base-patch32')
|
||||||
|
self.proj = nn.Linear(in_features=in_features, out_features=out_features)
|
||||||
|
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
|
||||||
|
self.Temp_transformer = SA_Transformer(temporal_dim, sa_depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.)
|
||||||
|
self.Temp_pos_embedding = nn.Parameter(torch.randn(1, duration*video_fps, temporal_dim))
|
||||||
|
|
||||||
|
clip_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
clip_std = [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
self.preprocess_CLIP = transforms.Compose([
|
||||||
|
transforms.Normalize(mean=clip_mean, std=clip_std)
|
||||||
|
])
|
||||||
|
|
||||||
|
def process_video_with_custom_preprocessing(self, video_tensor):
|
||||||
|
video_tensor = video_tensor / 255.0
|
||||||
|
video_tensor = self.preprocess_CLIP(video_tensor)
|
||||||
|
return video_tensor
|
||||||
|
|
||||||
|
def init_first_from_ckpt(self, path):
|
||||||
|
model = torch.load(path, map_location="cpu")
|
||||||
|
if "state_dict" in list(model.keys()):
|
||||||
|
model = model["state_dict"]
|
||||||
|
# Remove: module prefix
|
||||||
|
new_model = {}
|
||||||
|
for key in model.keys():
|
||||||
|
new_key = key.replace("module.","")
|
||||||
|
new_model[new_key] = model[key]
|
||||||
|
missing, unexpected = self.visual_encoder_model.load_state_dict(new_model, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def forward(self, Video_tensors: tp.List[torch.Tensor], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
visual_encoder_model = self.visual_encoder_model.eval().to(device)
|
||||||
|
proj = self.proj.to(device)
|
||||||
|
|
||||||
|
original_videos = torch.cat(Video_tensors, dim=0).to(device)
|
||||||
|
batch_size, time_length, _, _, _ = original_videos.size()
|
||||||
|
is_zero = torch.all(original_videos == 0, dim=(1,2,3,4))
|
||||||
|
Video_tensors = original_videos
|
||||||
|
Video_tensors = einops.rearrange(Video_tensors, 'b t c h w -> (b t) c h w')
|
||||||
|
|
||||||
|
video_cond_pixel_values = self.process_video_with_custom_preprocessing(video_tensor=Video_tensors.to(device)).to(device)
|
||||||
|
if self.clip_model_name=='clip-vit-base-patch32':
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = visual_encoder_model(pixel_values=video_cond_pixel_values)
|
||||||
|
video_hidden = outputs.last_hidden_state
|
||||||
|
|
||||||
|
video_hidden = einops.rearrange(video_hidden, '(b t) q h -> (b q) t h',b=batch_size,t=time_length)
|
||||||
|
video_hidden += self.Temp_pos_embedding
|
||||||
|
video_hidden = self.Temp_transformer(video_hidden)
|
||||||
|
video_hidden = einops.rearrange(video_hidden, '(b q) t h -> b (t q) h',b=batch_size,t=time_length)
|
||||||
|
|
||||||
|
video_hidden = proj(video_hidden.view(-1, self.in_features))
|
||||||
|
video_hidden = video_hidden.view(batch_size, self.out_features, -1)
|
||||||
|
|
||||||
|
empty_visual_feat = self.empty_visual_feat.expand(batch_size, -1, -1)
|
||||||
|
is_zero_expanded = is_zero.view(batch_size, 1, 1)
|
||||||
|
video_hidden = torch.where(is_zero_expanded, empty_visual_feat, video_hidden)
|
||||||
|
|
||||||
|
return video_hidden, torch.ones(video_hidden.shape[0], 1).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class T5Conditioner(Conditioner):
|
||||||
|
|
||||||
|
T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
|
||||||
|
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
|
||||||
|
"google/flan-t5-xl", "google/flan-t5-xxl"]
|
||||||
|
|
||||||
|
T5_MODEL_DIMS = {
|
||||||
|
"t5-small": 512,
|
||||||
|
"t5-base": 768,
|
||||||
|
"t5-large": 1024,
|
||||||
|
"t5-3b": 1024,
|
||||||
|
"t5-11b": 1024,
|
||||||
|
"t5-xl": 2048,
|
||||||
|
"t5-xxl": 4096,
|
||||||
|
"google/flan-t5-small": 512,
|
||||||
|
"google/flan-t5-base": 768,
|
||||||
|
"google/flan-t5-large": 1024,
|
||||||
|
"google/flan-t5-3b": 1024,
|
||||||
|
"google/flan-t5-11b": 1024,
|
||||||
|
"google/flan-t5-xl": 2048,
|
||||||
|
"google/flan-t5-xxl": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_dim: int,
|
||||||
|
t5_model_name: str = "t5-base",
|
||||||
|
max_length: str = 128,
|
||||||
|
enable_grad: bool = False,
|
||||||
|
project_out: bool = False,
|
||||||
|
):
|
||||||
|
assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
|
||||||
|
super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
from transformers import T5EncoderModel, AutoTokenizer
|
||||||
|
|
||||||
|
self.max_length = max_length
|
||||||
|
self.enable_grad = enable_grad
|
||||||
|
# Suppress logging from transformers
|
||||||
|
previous_level = logging.root.manager.disable
|
||||||
|
logging.disable(logging.ERROR)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
try:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
|
||||||
|
model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
|
||||||
|
finally:
|
||||||
|
logging.disable(previous_level)
|
||||||
|
|
||||||
|
if self.enable_grad:
|
||||||
|
self.model = model
|
||||||
|
else:
|
||||||
|
self.__dict__["model"] = model
|
||||||
|
|
||||||
|
def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
self.model.to(device)
|
||||||
|
self.proj_out.to(device)
|
||||||
|
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
texts,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = encoded["input_ids"].to(device)
|
||||||
|
attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
|
||||||
|
embeddings = self.model(
|
||||||
|
input_ids=input_ids, attention_mask=attention_mask
|
||||||
|
)["last_hidden_state"]
|
||||||
|
|
||||||
|
embeddings = self.proj_out(embeddings.float())
|
||||||
|
embeddings = embeddings * attention_mask.unsqueeze(-1).float()
|
||||||
|
|
||||||
|
return embeddings, attention_mask
|
||||||
|
|
||||||
|
class PhonemeConditioner(Conditioner):
|
||||||
|
"""
|
||||||
|
A conditioner that turns text into phonemes and embeds them using a lookup table
|
||||||
|
Only works for English text
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dim: the dimension of the output embeddings
|
||||||
|
max_length: the maximum number of phonemes to embed
|
||||||
|
project_out: whether to add another linear projection to the output embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_dim: int,
|
||||||
|
max_length: int = 1024,
|
||||||
|
project_out: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
from g2p_en import G2p
|
||||||
|
self.max_length = max_length
|
||||||
|
self.g2p = G2p()
|
||||||
|
# Reserving 0 for padding, 1 for ignored
|
||||||
|
self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
|
||||||
|
|
||||||
|
def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
self.phoneme_embedder.to(device)
|
||||||
|
self.proj_out.to(device)
|
||||||
|
|
||||||
|
batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
|
||||||
|
phoneme_ignore = [" ", *string.punctuation]
|
||||||
|
# Remove ignored phonemes and cut to max length
|
||||||
|
batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
|
||||||
|
|
||||||
|
# Convert to ids
|
||||||
|
phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
|
||||||
|
|
||||||
|
#Pad to match longest and make a mask tensor for the padding
|
||||||
|
longest = max([len(ids) for ids in phoneme_ids])
|
||||||
|
phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
|
||||||
|
phoneme_ids = torch.tensor(phoneme_ids).to(device)
|
||||||
|
|
||||||
|
# Convert to embeddings
|
||||||
|
phoneme_embeds = self.phoneme_embedder(phoneme_ids)
|
||||||
|
phoneme_embeds = self.proj_out(phoneme_embeds)
|
||||||
|
|
||||||
|
return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerLUTConditioner(Conditioner):
|
||||||
|
"""
|
||||||
|
A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
|
||||||
|
output_dim: the dimension of the output embeddings
|
||||||
|
max_length: the maximum length of the text to embed
|
||||||
|
project_out: whether to add another linear projection to the output embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
|
||||||
|
output_dim: int,
|
||||||
|
max_length: int = 1024,
|
||||||
|
project_out: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim, project_out=project_out)
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
# Suppress logging from transformers
|
||||||
|
previous_level = logging.root.manager.disable
|
||||||
|
logging.disable(logging.ERROR)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
try:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
finally:
|
||||||
|
logging.disable(previous_level)
|
||||||
|
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
|
||||||
|
|
||||||
|
def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
self.proj_out.to(device)
|
||||||
|
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
texts,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = encoded["input_ids"].to(device)
|
||||||
|
attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
|
||||||
|
|
||||||
|
embeddings = self.token_embedder(input_ids)
|
||||||
|
|
||||||
|
embeddings = self.proj_out(embeddings)
|
||||||
|
|
||||||
|
embeddings = embeddings * attention_mask.unsqueeze(-1).float()
|
||||||
|
|
||||||
|
return embeddings, attention_mask
|
||||||
|
|
||||||
|
class PretransformConditioner(Conditioner):
|
||||||
|
"""
|
||||||
|
A conditioner that uses a pretransform's encoder for conditioning
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretransform: an instantiated pretransform to use for conditioning
|
||||||
|
output_dim: the dimension of the output embeddings
|
||||||
|
"""
|
||||||
|
def __init__(self, pretransform: Pretransform, output_dim: int):
|
||||||
|
super().__init__(pretransform.encoded_channels, output_dim)
|
||||||
|
|
||||||
|
self.pretransform = pretransform
|
||||||
|
|
||||||
|
def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
self.pretransform.to(device)
|
||||||
|
self.proj_out.to(device)
|
||||||
|
|
||||||
|
if isinstance(audio, list) or isinstance(audio, tuple):
|
||||||
|
audio = torch.cat(audio, dim=0)
|
||||||
|
|
||||||
|
# Convert audio to pretransform input channels
|
||||||
|
audio = set_audio_channels(audio, self.pretransform.io_channels)
|
||||||
|
|
||||||
|
latents = self.pretransform.encode(audio)
|
||||||
|
latents = self.proj_out(latents)
|
||||||
|
|
||||||
|
return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioAutoencoderConditioner(Conditioner):
|
||||||
|
"""
|
||||||
|
A conditioner that uses a pretransform's encoder for conditioning
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretransform: an instantiated pretransform to use for conditioning
|
||||||
|
output_dim: the dimension of the output embeddings
|
||||||
|
"""
|
||||||
|
def __init__(self, pretransform: Pretransform, output_dim: int):
|
||||||
|
super().__init__(pretransform.encoded_channels, output_dim)
|
||||||
|
|
||||||
|
self.pretransform = pretransform
|
||||||
|
self.empty_audio_feat = nn.Parameter(torch.zeros(1, 215, self.proj_out.out_features), requires_grad=True)
|
||||||
|
nn.init.constant_(self.empty_audio_feat, 0)
|
||||||
|
|
||||||
|
def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
self.pretransform.to(device)
|
||||||
|
self.proj_out.to(device)
|
||||||
|
|
||||||
|
if isinstance(audio, list) or isinstance(audio, tuple):
|
||||||
|
original_audios = torch.cat(audio, dim=0).to(device)
|
||||||
|
is_zero = torch.all(original_audios == 0, dim=(1,2))
|
||||||
|
audio = original_audios
|
||||||
|
|
||||||
|
# Convert audio to pretransform input channels
|
||||||
|
audio = set_audio_channels(audio, self.pretransform.io_channels)
|
||||||
|
|
||||||
|
latents = self.pretransform.encode(audio)
|
||||||
|
latents = latents.permute(0, 2, 1)
|
||||||
|
latents = self.proj_out(latents)
|
||||||
|
|
||||||
|
empty_audio_feat = self.empty_audio_feat.expand(latents.shape[0], -1, -1)
|
||||||
|
is_zero_expanded = is_zero.view(latents.shape[0], 1, 1)
|
||||||
|
latents = torch.where(is_zero_expanded, empty_audio_feat, latents)
|
||||||
|
return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
|
||||||
|
|
||||||
|
|
||||||
|
class MultiConditioner(nn.Module):
|
||||||
|
"""
|
||||||
|
A module that applies multiple conditioners to an input dictionary based on the keys
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
|
||||||
|
default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
|
||||||
|
"""
|
||||||
|
def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conditioners = nn.ModuleDict(conditioners)
|
||||||
|
self.default_keys = default_keys
|
||||||
|
|
||||||
|
def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
|
||||||
|
output = {}
|
||||||
|
|
||||||
|
for key, conditioner in self.conditioners.items():
|
||||||
|
condition_key = key
|
||||||
|
|
||||||
|
conditioner_inputs = []
|
||||||
|
|
||||||
|
for x in batch_metadata:
|
||||||
|
|
||||||
|
if condition_key not in x:
|
||||||
|
if condition_key in self.default_keys:
|
||||||
|
condition_key = self.default_keys[condition_key]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
|
||||||
|
|
||||||
|
if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1:
|
||||||
|
conditioner_input = x[condition_key][0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
conditioner_input = x[condition_key]
|
||||||
|
|
||||||
|
conditioner_inputs.append(conditioner_input)
|
||||||
|
|
||||||
|
output[key] = conditioner(conditioner_inputs, device)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
|
||||||
|
"""
|
||||||
|
Create a MultiConditioner from a conditioning config dictionary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: the conditioning config dictionary
|
||||||
|
device: the device to put the conditioners on
|
||||||
|
"""
|
||||||
|
conditioners = {}
|
||||||
|
cond_dim = config["cond_dim"]
|
||||||
|
|
||||||
|
default_keys = config.get("default_keys", {})
|
||||||
|
|
||||||
|
for conditioner_info in config["configs"]:
|
||||||
|
id = conditioner_info["id"]
|
||||||
|
|
||||||
|
conditioner_type = conditioner_info["type"]
|
||||||
|
|
||||||
|
conditioner_config = {"output_dim": cond_dim}
|
||||||
|
|
||||||
|
conditioner_config.update(conditioner_info["config"])
|
||||||
|
|
||||||
|
if conditioner_type == "t5":
|
||||||
|
conditioners[id] = T5Conditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clip":
|
||||||
|
conditioners[id] = CLIPConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clap_text":
|
||||||
|
conditioners[id] = CLAPTextConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clap_audio":
|
||||||
|
conditioners[id] = CLAPAudioConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "int":
|
||||||
|
conditioners[id] = IntConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "number":
|
||||||
|
conditioners[id] = NumberConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "phoneme":
|
||||||
|
conditioners[id] = PhonemeConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "lut":
|
||||||
|
conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "pretransform":
|
||||||
|
sample_rate = conditioner_config.pop("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
|
||||||
|
|
||||||
|
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
|
||||||
|
|
||||||
|
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
|
||||||
|
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
|
||||||
|
|
||||||
|
conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
|
||||||
|
|
||||||
|
elif conditioner_type == "audio_autoencoder":
|
||||||
|
sample_rate = conditioner_config.pop("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
|
||||||
|
|
||||||
|
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
|
||||||
|
|
||||||
|
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
|
||||||
|
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
|
||||||
|
|
||||||
|
conditioners[id] = AudioAutoencoderConditioner(pretransform, **conditioner_config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown conditioner type: {conditioner_type}")
|
||||||
|
|
||||||
|
return MultiConditioner(conditioners, default_keys=default_keys)
|
||||||
704
stable_audio_tools/models/diffusion.py
Normal file
704
stable_audio_tools/models/diffusion.py
Normal file
|
|
@ -0,0 +1,704 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from functools import partial
|
||||||
|
import numpy as np
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
||||||
|
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
|
||||||
|
from .dit import DiffusionTransformer
|
||||||
|
from .factory import create_pretransform_from_config
|
||||||
|
from .pretransforms import Pretransform
|
||||||
|
from ..inference.generation import generate_diffusion_cond
|
||||||
|
|
||||||
|
from .adp import UNetCFG1d, UNet1d
|
||||||
|
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
class Profiler:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.ticks = [[time(), None]]
|
||||||
|
|
||||||
|
def tick(self, msg):
|
||||||
|
self.ticks.append([time(), msg])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
rep = 80 * "=" + "\n"
|
||||||
|
for i in range(1, len(self.ticks)):
|
||||||
|
msg = self.ticks[i][1]
|
||||||
|
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
|
||||||
|
rep += msg + f": {ellapsed*1000:.2f}ms\n"
|
||||||
|
rep += 80 * "=" + "\n\n\n"
|
||||||
|
return rep
|
||||||
|
|
||||||
|
class DiffusionModel(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class DiffusionModelWrapper(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: DiffusionModel,
|
||||||
|
io_channels,
|
||||||
|
sample_size,
|
||||||
|
sample_rate,
|
||||||
|
min_input_length,
|
||||||
|
pretransform: tp.Optional[Pretransform] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
self.pretransform = pretransform
|
||||||
|
else:
|
||||||
|
self.pretransform = None
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
class ConditionedDiffusionModel(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
*args,
|
||||||
|
supports_cross_attention: bool = False,
|
||||||
|
supports_input_concat: bool = False,
|
||||||
|
supports_global_cond: bool = False,
|
||||||
|
supports_prepend_cond: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.supports_cross_attention = supports_cross_attention
|
||||||
|
self.supports_input_concat = supports_input_concat
|
||||||
|
self.supports_global_cond = supports_global_cond
|
||||||
|
self.supports_prepend_cond = supports_prepend_cond
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
cross_attn_cond: torch.Tensor = None,
|
||||||
|
cross_attn_mask: torch.Tensor = None,
|
||||||
|
input_concat_cond: torch.Tensor = None,
|
||||||
|
global_embed: torch.Tensor = None,
|
||||||
|
prepend_cond: torch.Tensor = None,
|
||||||
|
prepend_cond_mask: torch.Tensor = None,
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class ConditionedDiffusionModelWrapper(nn.Module):
|
||||||
|
"""
|
||||||
|
A diffusion model that takes in conditioning
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ConditionedDiffusionModel,
|
||||||
|
conditioner: MultiConditioner,
|
||||||
|
io_channels,
|
||||||
|
sample_rate,
|
||||||
|
min_input_length: int,
|
||||||
|
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||||
|
pretransform: tp.Optional[Pretransform] = None,
|
||||||
|
cross_attn_cond_ids: tp.List[str] = [],
|
||||||
|
global_cond_ids: tp.List[str] = [],
|
||||||
|
input_concat_ids: tp.List[str] = [],
|
||||||
|
prepend_cond_ids: tp.List[str] = [],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.conditioner = conditioner
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.diffusion_objective = diffusion_objective
|
||||||
|
self.pretransform = pretransform
|
||||||
|
self.cross_attn_cond_ids = cross_attn_cond_ids # ['prompt', 'seconds_start', 'seconds_total']
|
||||||
|
self.global_cond_ids = global_cond_ids # ['seconds_start', 'seconds_total']
|
||||||
|
self.input_concat_ids = input_concat_ids
|
||||||
|
self.prepend_cond_ids = prepend_cond_ids
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
|
||||||
|
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[torch.Tensor, tp.Any], negative=False):
|
||||||
|
cross_attention_input = None
|
||||||
|
cross_attention_masks = None
|
||||||
|
global_cond = None
|
||||||
|
input_concat_cond = None
|
||||||
|
prepend_cond = None
|
||||||
|
prepend_cond_mask = None
|
||||||
|
|
||||||
|
if len(self.cross_attn_cond_ids) > 0:
|
||||||
|
# Concatenate all cross-attention inputs over the sequence dimension
|
||||||
|
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||||
|
cross_attention_input = []
|
||||||
|
cross_attention_masks = []
|
||||||
|
|
||||||
|
for key in self.cross_attn_cond_ids:
|
||||||
|
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
|
||||||
|
|
||||||
|
# Add sequence dimension if it's not there
|
||||||
|
if len(cross_attn_in.shape) == 2:
|
||||||
|
cross_attn_in = cross_attn_in.unsqueeze(1)
|
||||||
|
cross_attn_mask = cross_attn_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
cross_attention_input.append(cross_attn_in)
|
||||||
|
cross_attention_masks.append(cross_attn_mask)
|
||||||
|
|
||||||
|
cross_attention_input = torch.cat(cross_attention_input, dim=1) # [1, 130, 768] (text feature:128)
|
||||||
|
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||||
|
|
||||||
|
if len(self.global_cond_ids) > 0:
|
||||||
|
# Concatenate all global conditioning inputs over the channel dimension
|
||||||
|
# Assumes that the global conditioning inputs are of shape (batch, channels)
|
||||||
|
global_conds = []
|
||||||
|
for key in self.global_cond_ids:
|
||||||
|
|
||||||
|
global_cond_input = conditioning_tensors[key][0]
|
||||||
|
|
||||||
|
global_conds.append(global_cond_input)
|
||||||
|
|
||||||
|
# Concatenate over the channel dimension
|
||||||
|
global_cond = torch.cat(global_conds, dim=-1)
|
||||||
|
|
||||||
|
if len(global_cond.shape) == 3:
|
||||||
|
global_cond = global_cond.squeeze(1)
|
||||||
|
|
||||||
|
if len(self.input_concat_ids) > 0: # False
|
||||||
|
# Concatenate all input concat conditioning inputs over the channel dimension
|
||||||
|
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
|
||||||
|
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
|
||||||
|
|
||||||
|
if len(self.prepend_cond_ids) > 0: # False
|
||||||
|
# Concatenate all prepend conditioning inputs over the sequence dimension
|
||||||
|
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
|
||||||
|
prepend_conds = []
|
||||||
|
prepend_cond_masks = []
|
||||||
|
|
||||||
|
for key in self.prepend_cond_ids:
|
||||||
|
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
|
||||||
|
prepend_conds.append(prepend_cond_input)
|
||||||
|
prepend_cond_masks.append(prepend_cond_mask)
|
||||||
|
|
||||||
|
prepend_cond = torch.cat(prepend_conds, dim=1)
|
||||||
|
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
|
||||||
|
|
||||||
|
if negative: # False
|
||||||
|
return {
|
||||||
|
"negative_cross_attn_cond": cross_attention_input,
|
||||||
|
"negative_cross_attn_mask": cross_attention_masks,
|
||||||
|
"negative_global_cond": global_cond,
|
||||||
|
"negative_input_concat_cond": input_concat_cond
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"cross_attn_cond": cross_attention_input,
|
||||||
|
"cross_attn_mask": cross_attention_masks,
|
||||||
|
"global_cond": global_cond,
|
||||||
|
"input_concat_cond": input_concat_cond,
|
||||||
|
"prepend_cond": prepend_cond,
|
||||||
|
"prepend_cond_mask": prepend_cond_mask
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||||
|
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||||
|
|
||||||
|
def generate(self, *args, **kwargs):
|
||||||
|
return generate_diffusion_cond(self, *args, **kwargs)
|
||||||
|
|
||||||
|
class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = UNetCFG1d(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
**kwargs):
|
||||||
|
p = Profiler()
|
||||||
|
|
||||||
|
p.tick("start")
|
||||||
|
|
||||||
|
channels_list = None
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
channels_list = [input_concat_cond]
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
embedding=cross_attn_cond,
|
||||||
|
embedding_mask=cross_attn_mask,
|
||||||
|
features=global_cond,
|
||||||
|
channels_list=channels_list,
|
||||||
|
embedding_scale=cfg_scale,
|
||||||
|
embedding_mask_proba=cfg_dropout_prob,
|
||||||
|
batch_cfg=batch_cfg,
|
||||||
|
rescale_cfg=rescale_cfg,
|
||||||
|
negative_embedding=negative_cross_attn_cond,
|
||||||
|
negative_embedding_mask=negative_cross_attn_mask,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
p.tick("UNetCFG1D forward")
|
||||||
|
|
||||||
|
#print(f"Profiler: {p}")
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = UNet1d(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
channels_list = None
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
|
||||||
|
# Interpolate input_concat_cond to the same length as x
|
||||||
|
if input_concat_cond.shape[2] != x.shape[2]:
|
||||||
|
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||||
|
|
||||||
|
channels_list = [input_concat_cond]
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
features=global_cond,
|
||||||
|
channels_list=channels_list,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class UNet1DUncondWrapper(DiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
|
||||||
|
|
||||||
|
self.io_channels = in_channels
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
class DAU1DCondWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = DiffusionAttnUnet1D(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
input_concat_cond=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
global_cond=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
return self.model(x, t, cond = input_concat_cond)
|
||||||
|
|
||||||
|
class DiffusionAttnUnet1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
io_channels = 2,
|
||||||
|
depth=14,
|
||||||
|
n_attn_layers = 6,
|
||||||
|
channels = [128, 128, 256, 256] + [512] * 10,
|
||||||
|
cond_dim = 0,
|
||||||
|
cond_noise_aug = False,
|
||||||
|
kernel_size = 5,
|
||||||
|
learned_resample = False,
|
||||||
|
strides = [2] * 13,
|
||||||
|
conv_bias = True,
|
||||||
|
use_snake = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_noise_aug = cond_noise_aug
|
||||||
|
|
||||||
|
self.io_channels = io_channels
|
||||||
|
|
||||||
|
if self.cond_noise_aug:
|
||||||
|
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
||||||
|
|
||||||
|
self.timestep_embed = FourierFeatures(1, 16)
|
||||||
|
|
||||||
|
attn_layer = depth - n_attn_layers
|
||||||
|
|
||||||
|
strides = [1] + strides
|
||||||
|
|
||||||
|
block = nn.Identity()
|
||||||
|
|
||||||
|
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
|
||||||
|
|
||||||
|
for i in range(depth, 0, -1):
|
||||||
|
c = channels[i - 1]
|
||||||
|
stride = strides[i-1]
|
||||||
|
if stride > 2 and not learned_resample:
|
||||||
|
raise ValueError("Must have stride 2 without learned resampling")
|
||||||
|
|
||||||
|
if i > 1:
|
||||||
|
c_prev = channels[i - 2]
|
||||||
|
add_attn = i >= attn_layer and n_attn_layers > 0
|
||||||
|
block = SkipBlock(
|
||||||
|
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
|
||||||
|
conv_block(c_prev, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
block,
|
||||||
|
conv_block(c * 2 if i != depth else c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c_prev),
|
||||||
|
SelfAttention1d(c_prev, c_prev //
|
||||||
|
32) if add_attn else nn.Identity(),
|
||||||
|
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_embed_dim = 16 if not self.cond_noise_aug else 32
|
||||||
|
block = nn.Sequential(
|
||||||
|
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
block,
|
||||||
|
conv_block(c * 2, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
conv_block(c, c, io_channels, is_last=True),
|
||||||
|
)
|
||||||
|
self.net = block
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.net.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, cond=None, cond_aug_scale=None):
|
||||||
|
|
||||||
|
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
|
||||||
|
|
||||||
|
inputs = [x, timestep_embed]
|
||||||
|
|
||||||
|
if cond is not None:
|
||||||
|
if cond.shape[2] != x.shape[2]:
|
||||||
|
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
|
||||||
|
|
||||||
|
if self.cond_noise_aug:
|
||||||
|
# Get a random number between 0 and 1, uniformly sampled
|
||||||
|
if cond_aug_scale is None:
|
||||||
|
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
|
||||||
|
else:
|
||||||
|
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
|
||||||
|
|
||||||
|
# Add noise to the conditioning signal
|
||||||
|
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
|
||||||
|
|
||||||
|
# Get embedding for noise cond level, reusing timestamp_embed
|
||||||
|
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
|
||||||
|
|
||||||
|
inputs.append(aug_level_embed)
|
||||||
|
|
||||||
|
inputs.append(cond)
|
||||||
|
|
||||||
|
outputs = self.net(torch.cat(inputs, dim=1))
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class DiTWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
||||||
|
|
||||||
|
self.model = DiffusionTransformer(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = True,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
scale_phi: float = 0.0,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
||||||
|
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
||||||
|
|
||||||
|
return self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=cross_attn_cond,
|
||||||
|
cross_attn_cond_mask=cross_attn_mask,
|
||||||
|
negative_cross_attn_cond=negative_cross_attn_cond,
|
||||||
|
negative_cross_attn_mask=negative_cross_attn_mask,
|
||||||
|
input_concat_cond=input_concat_cond,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
cfg_dropout_prob=cfg_dropout_prob,
|
||||||
|
scale_phi=scale_phi,
|
||||||
|
global_embed=global_cond,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
class DiTUncondWrapper(DiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
|
||||||
|
|
||||||
|
self.io_channels = in_channels
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
diffusion_uncond_config = config["model"]
|
||||||
|
|
||||||
|
model_type = diffusion_uncond_config.get('type', None)
|
||||||
|
|
||||||
|
diffusion_config = diffusion_uncond_config.get('config', {})
|
||||||
|
|
||||||
|
assert model_type is not None, "Must specify model type in config"
|
||||||
|
|
||||||
|
pretransform = diffusion_uncond_config.get("pretransform", None)
|
||||||
|
|
||||||
|
sample_size = config.get("sample_size", None)
|
||||||
|
assert sample_size is not None, "Must specify sample size in config"
|
||||||
|
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Must specify sample rate in config"
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if model_type == 'DAU1d':
|
||||||
|
|
||||||
|
model = DiffusionAttnUnet1D(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "adp_uncond_1d":
|
||||||
|
|
||||||
|
model = UNet1DUncondWrapper(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "dit":
|
||||||
|
model = DiTUncondWrapper(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
return DiffusionModelWrapper(model,
|
||||||
|
io_channels=model.io_channels,
|
||||||
|
sample_size=sample_size,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
pretransform=pretransform,
|
||||||
|
min_input_length=min_input_length)
|
||||||
|
|
||||||
|
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
|
||||||
|
model_config = config["model"]
|
||||||
|
|
||||||
|
model_type = config["model_type"]
|
||||||
|
|
||||||
|
diffusion_config = model_config.get('diffusion', None)
|
||||||
|
assert diffusion_config is not None, "Must specify diffusion config"
|
||||||
|
|
||||||
|
diffusion_model_type = diffusion_config.get('type', None)
|
||||||
|
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
||||||
|
|
||||||
|
diffusion_model_config = diffusion_config.get('config', None)
|
||||||
|
if diffusion_model_config.get('video_fps', None) is not None:
|
||||||
|
diffusion_model_config.pop('video_fps')
|
||||||
|
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
||||||
|
|
||||||
|
if diffusion_model_type == 'adp_cfg_1d':
|
||||||
|
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'adp_1d':
|
||||||
|
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'dit':
|
||||||
|
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||||
|
|
||||||
|
io_channels = model_config.get('io_channels', None)
|
||||||
|
assert io_channels is not None, "Must specify io_channels in model config"
|
||||||
|
|
||||||
|
sample_rate = config.get('sample_rate', None)
|
||||||
|
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||||
|
|
||||||
|
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
||||||
|
|
||||||
|
conditioning_config = model_config.get('conditioning', None)
|
||||||
|
|
||||||
|
conditioner = None
|
||||||
|
if conditioning_config is not None:
|
||||||
|
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
|
||||||
|
|
||||||
|
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
||||||
|
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
||||||
|
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
||||||
|
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
||||||
|
|
||||||
|
pretransform = model_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
||||||
|
min_input_length *= np.prod(diffusion_model_config["factors"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
min_input_length *= diffusion_model.model.patch_size
|
||||||
|
|
||||||
|
# Get the proper wrapper class
|
||||||
|
|
||||||
|
extra_kwargs = {}
|
||||||
|
|
||||||
|
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
|
||||||
|
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||||
|
|
||||||
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
|
elif model_type == "diffusion_prior":
|
||||||
|
prior_type = model_config.get("prior_type", None)
|
||||||
|
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
|
||||||
|
|
||||||
|
if prior_type == "mono_stereo":
|
||||||
|
from .diffusion_prior import MonoToStereoDiffusionPrior
|
||||||
|
wrapper_fn = MonoToStereoDiffusionPrior
|
||||||
|
|
||||||
|
return wrapper_fn(
|
||||||
|
diffusion_model,
|
||||||
|
conditioner,
|
||||||
|
min_input_length=min_input_length,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
cross_attn_cond_ids=cross_attention_ids,
|
||||||
|
global_cond_ids=global_cond_ids,
|
||||||
|
input_concat_ids=input_concat_ids,
|
||||||
|
prepend_cond_ids=prepend_cond_ids,
|
||||||
|
pretransform=pretransform,
|
||||||
|
io_channels=io_channels,
|
||||||
|
**extra_kwargs
|
||||||
|
)
|
||||||
546
stable_audio_tools/models/discriminators.py
Normal file
546
stable_audio_tools/models/discriminators.py
Normal file
|
|
@ -0,0 +1,546 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from functools import reduce
|
||||||
|
import typing as tp
|
||||||
|
from einops import rearrange
|
||||||
|
from audiotools import AudioSignal, STFTParams
|
||||||
|
from dac.model.discriminator import WNConv1d, WNConv2d
|
||||||
|
|
||||||
|
def get_hinge_losses(score_real, score_fake):
|
||||||
|
gen_loss = -score_fake.mean()
|
||||||
|
dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
|
||||||
|
return dis_loss, gen_loss
|
||||||
|
|
||||||
|
class EncodecDiscriminator(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
from encodec.msstftd import MultiScaleSTFTDiscriminator
|
||||||
|
|
||||||
|
self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
logits, features = self.discriminators(x)
|
||||||
|
return logits, features
|
||||||
|
|
||||||
|
def loss(self, x, y):
|
||||||
|
feature_matching_distance = 0.
|
||||||
|
logits_true, feature_true = self.forward(x)
|
||||||
|
logits_fake, feature_fake = self.forward(y)
|
||||||
|
|
||||||
|
dis_loss = torch.tensor(0.)
|
||||||
|
adv_loss = torch.tensor(0.)
|
||||||
|
|
||||||
|
for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
|
||||||
|
|
||||||
|
feature_matching_distance = feature_matching_distance + sum(
|
||||||
|
map(
|
||||||
|
lambda x, y: abs(x - y).mean(),
|
||||||
|
scale_true,
|
||||||
|
scale_fake,
|
||||||
|
)) / len(scale_true)
|
||||||
|
|
||||||
|
_dis, _adv = get_hinge_losses(
|
||||||
|
logits_true[i],
|
||||||
|
logits_fake[i],
|
||||||
|
)
|
||||||
|
|
||||||
|
dis_loss = dis_loss + _dis
|
||||||
|
adv_loss = adv_loss + _adv
|
||||||
|
|
||||||
|
return dis_loss, adv_loss, feature_matching_distance
|
||||||
|
|
||||||
|
# Discriminators from oobleck
|
||||||
|
|
||||||
|
IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
|
||||||
|
|
||||||
|
TensorDict = tp.Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
class SharedDiscriminatorConvNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_size: int,
|
||||||
|
convolution: tp.Union[nn.Conv1d, nn.Conv2d],
|
||||||
|
out_size: int = 1,
|
||||||
|
capacity: int = 32,
|
||||||
|
n_layers: int = 4,
|
||||||
|
kernel_size: int = 15,
|
||||||
|
stride: int = 4,
|
||||||
|
activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
|
||||||
|
normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
channels = [in_size]
|
||||||
|
channels += list(capacity * 2**np.arange(n_layers))
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = n_layers * [stride]
|
||||||
|
|
||||||
|
net = []
|
||||||
|
for i in range(n_layers):
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
pad = kernel_size // 2
|
||||||
|
s = stride[i]
|
||||||
|
else:
|
||||||
|
pad = kernel_size[0] // 2
|
||||||
|
s = (stride[i], 1)
|
||||||
|
|
||||||
|
net.append(
|
||||||
|
normalization(
|
||||||
|
convolution(
|
||||||
|
channels[i],
|
||||||
|
channels[i + 1],
|
||||||
|
kernel_size,
|
||||||
|
stride=s,
|
||||||
|
padding=pad,
|
||||||
|
)))
|
||||||
|
net.append(activation())
|
||||||
|
|
||||||
|
net.append(convolution(channels[-1], out_size, 1))
|
||||||
|
|
||||||
|
self.net = nn.ModuleList(net)
|
||||||
|
|
||||||
|
def forward(self, x) -> IndividualDiscriminatorOut:
|
||||||
|
features = []
|
||||||
|
for layer in self.net:
|
||||||
|
x = layer(x)
|
||||||
|
if isinstance(layer, nn.modules.conv._ConvNd):
|
||||||
|
features.append(x)
|
||||||
|
score = x.reshape(x.shape[0], -1).mean(-1)
|
||||||
|
return score, features
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDiscriminator(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
n_scales: int,
|
||||||
|
**conv_kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
for _ in range(n_scales):
|
||||||
|
layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
|
||||||
|
self.layers = nn.ModuleList(layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
|
||||||
|
score = 0
|
||||||
|
features = []
|
||||||
|
for layer in self.layers:
|
||||||
|
s, f = layer(x)
|
||||||
|
score = score + s
|
||||||
|
features.extend(f)
|
||||||
|
x = nn.functional.avg_pool1d(x, 2)
|
||||||
|
return score, features
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
periods: tp.Sequence[int],
|
||||||
|
**conv_kwargs) -> None:
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
self.periods = periods
|
||||||
|
|
||||||
|
for _ in periods:
|
||||||
|
layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
|
||||||
|
score = 0
|
||||||
|
features = []
|
||||||
|
for layer, n in zip(self.layers, self.periods):
|
||||||
|
s, f = layer(self.fold(x, n))
|
||||||
|
score = score + s
|
||||||
|
features.extend(f)
|
||||||
|
return score, features
|
||||||
|
|
||||||
|
def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
|
||||||
|
pad = (n - (x.shape[-1] % n)) % n
|
||||||
|
x = nn.functional.pad(x, (0, pad))
|
||||||
|
return x.reshape(*x.shape[:2], -1, n)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDiscriminator(nn.Module):
|
||||||
|
"""
|
||||||
|
Individual discriminators should take a single tensor as input (NxB C T) and
|
||||||
|
return a tuple composed of a score tensor (NxB) and a Sequence of Features
|
||||||
|
Sequence[NxB C' T'].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discriminator_list: tp.Sequence[nn.Module],
|
||||||
|
keys: tp.Sequence[str]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.discriminators = nn.ModuleList(discriminator_list)
|
||||||
|
self.keys = keys
|
||||||
|
|
||||||
|
def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
|
||||||
|
features = features.chunk(len(self.keys), 0)
|
||||||
|
return {k: features[i] for i, k in enumerate(self.keys)}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def concat_dicts(dict_a, dict_b):
|
||||||
|
out_dict = {}
|
||||||
|
keys = set(list(dict_a.keys()) + list(dict_b.keys()))
|
||||||
|
for k in keys:
|
||||||
|
out_dict[k] = []
|
||||||
|
if k in dict_a:
|
||||||
|
if isinstance(dict_a[k], list):
|
||||||
|
out_dict[k].extend(dict_a[k])
|
||||||
|
else:
|
||||||
|
out_dict[k].append(dict_a[k])
|
||||||
|
if k in dict_b:
|
||||||
|
if isinstance(dict_b[k], list):
|
||||||
|
out_dict[k].extend(dict_b[k])
|
||||||
|
else:
|
||||||
|
out_dict[k].append(dict_b[k])
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sum_dicts(dict_a, dict_b):
|
||||||
|
out_dict = {}
|
||||||
|
keys = set(list(dict_a.keys()) + list(dict_b.keys()))
|
||||||
|
for k in keys:
|
||||||
|
out_dict[k] = 0.
|
||||||
|
if k in dict_a:
|
||||||
|
out_dict[k] = out_dict[k] + dict_a[k]
|
||||||
|
if k in dict_b:
|
||||||
|
out_dict[k] = out_dict[k] + dict_b[k]
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def forward(self, inputs: TensorDict) -> TensorDict:
|
||||||
|
discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
|
||||||
|
all_scores = []
|
||||||
|
all_features = []
|
||||||
|
|
||||||
|
for discriminator in self.discriminators:
|
||||||
|
score, features = discriminator(discriminator_input)
|
||||||
|
scores = self.unpack_tensor_to_dict(score)
|
||||||
|
scores = {f"score_{k}": scores[k] for k in scores.keys()}
|
||||||
|
all_scores.append(scores)
|
||||||
|
|
||||||
|
features = map(self.unpack_tensor_to_dict, features)
|
||||||
|
features = reduce(self.concat_dicts, features)
|
||||||
|
features = {f"features_{k}": features[k] for k in features.keys()}
|
||||||
|
all_features.append(features)
|
||||||
|
|
||||||
|
all_scores = reduce(self.sum_dicts, all_scores)
|
||||||
|
all_features = reduce(self.concat_dicts, all_features)
|
||||||
|
|
||||||
|
inputs.update(all_scores)
|
||||||
|
inputs.update(all_features)
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
class OobleckDiscriminator(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
multi_scale_discriminator = MultiScaleDiscriminator(
|
||||||
|
in_channels=in_channels,
|
||||||
|
n_scales=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
multi_period_discriminator = MultiPeriodDiscriminator(
|
||||||
|
in_channels=in_channels,
|
||||||
|
periods=[2, 3, 5, 7, 11]
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
|
||||||
|
# filters=32,
|
||||||
|
# in_channels = in_channels,
|
||||||
|
# out_channels = 1,
|
||||||
|
# n_ffts = [2048, 1024, 512, 256, 128],
|
||||||
|
# hop_lengths = [512, 256, 128, 64, 32],
|
||||||
|
# win_lengths = [2048, 1024, 512, 256, 128]
|
||||||
|
# )
|
||||||
|
|
||||||
|
self.multi_discriminator = MultiDiscriminator(
|
||||||
|
[multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
|
||||||
|
["reals", "fakes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def loss(self, reals, fakes):
|
||||||
|
inputs = {
|
||||||
|
"reals": reals,
|
||||||
|
"fakes": fakes,
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = self.multi_discriminator(inputs)
|
||||||
|
|
||||||
|
scores_real = inputs["score_reals"]
|
||||||
|
scores_fake = inputs["score_fakes"]
|
||||||
|
|
||||||
|
features_real = inputs["features_reals"]
|
||||||
|
features_fake = inputs["features_fakes"]
|
||||||
|
|
||||||
|
dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
|
||||||
|
|
||||||
|
feature_matching_distance = torch.tensor(0.)
|
||||||
|
|
||||||
|
for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
|
||||||
|
|
||||||
|
feature_matching_distance = feature_matching_distance + sum(
|
||||||
|
map(
|
||||||
|
lambda real, fake: abs(real - fake).mean(),
|
||||||
|
scale_real,
|
||||||
|
scale_fake,
|
||||||
|
)) / len(scale_real)
|
||||||
|
|
||||||
|
return dis_loss, gen_loss, feature_matching_distance
|
||||||
|
|
||||||
|
|
||||||
|
## Discriminators from Descript Audio Codec repo
|
||||||
|
## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
|
||||||
|
class MPD(nn.Module):
|
||||||
|
def __init__(self, period, channels=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.period = period
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
|
||||||
|
WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
|
||||||
|
WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
|
||||||
|
WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
|
||||||
|
WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = WNConv2d(
|
||||||
|
1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def pad_to_period(self, x):
|
||||||
|
t = x.shape[-1]
|
||||||
|
x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
x = self.pad_to_period(x)
|
||||||
|
x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
|
||||||
|
|
||||||
|
for layer in self.convs:
|
||||||
|
x = layer(x)
|
||||||
|
fmap.append(x)
|
||||||
|
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MSD(nn.Module):
|
||||||
|
def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
WNConv1d(channels, 16, 15, 1, padding=7),
|
||||||
|
WNConv1d(16, 64, 41, 4, groups=4, padding=20),
|
||||||
|
WNConv1d(64, 256, 41, 4, groups=16, padding=20),
|
||||||
|
WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
|
||||||
|
WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
|
||||||
|
WNConv1d(1024, 1024, 5, 1, padding=2),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.rate = rate
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = AudioSignal(x, self.sample_rate)
|
||||||
|
x.resample(self.sample_rate // self.rate)
|
||||||
|
x = x.audio_data
|
||||||
|
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
|
||||||
|
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
|
||||||
|
|
||||||
|
|
||||||
|
class MRD(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_length: int,
|
||||||
|
hop_factor: float = 0.25,
|
||||||
|
sample_rate: int = 44100,
|
||||||
|
bands: list = BANDS,
|
||||||
|
channels: int = 1
|
||||||
|
):
|
||||||
|
"""Complex multi-band spectrogram discriminator.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
window_length : int
|
||||||
|
Window length of STFT.
|
||||||
|
hop_factor : float, optional
|
||||||
|
Hop factor of the STFT, defaults to ``0.25 * window_length``.
|
||||||
|
sample_rate : int, optional
|
||||||
|
Sampling rate of audio in Hz, by default 44100
|
||||||
|
bands : list, optional
|
||||||
|
Bands to run discriminator over.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.window_length = window_length
|
||||||
|
self.hop_factor = hop_factor
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.stft_params = STFTParams(
|
||||||
|
window_length=window_length,
|
||||||
|
hop_length=int(window_length * hop_factor),
|
||||||
|
match_stride=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
n_fft = window_length // 2 + 1
|
||||||
|
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||||
|
self.bands = bands
|
||||||
|
|
||||||
|
ch = 32
|
||||||
|
convs = lambda: nn.ModuleList(
|
||||||
|
[
|
||||||
|
WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
|
||||||
|
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||||
|
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||||
|
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||||
|
WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||||
|
self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
|
||||||
|
|
||||||
|
def spectrogram(self, x):
|
||||||
|
x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
|
||||||
|
x = torch.view_as_real(x.stft())
|
||||||
|
x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
|
||||||
|
# Split into bands
|
||||||
|
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||||
|
return x_bands
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_bands = self.spectrogram(x)
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
x = []
|
||||||
|
for band, stack in zip(x_bands, self.band_convs):
|
||||||
|
for layer in stack:
|
||||||
|
band = layer(band)
|
||||||
|
fmap.append(band)
|
||||||
|
x.append(band)
|
||||||
|
|
||||||
|
x = torch.cat(x, dim=-1)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
|
||||||
|
class DACDiscriminator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int = 1,
|
||||||
|
rates: list = [],
|
||||||
|
periods: list = [2, 3, 5, 7, 11],
|
||||||
|
fft_sizes: list = [2048, 1024, 512],
|
||||||
|
sample_rate: int = 44100,
|
||||||
|
bands: list = BANDS,
|
||||||
|
):
|
||||||
|
"""Discriminator that combines multiple discriminators.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rates : list, optional
|
||||||
|
sampling rates (in Hz) to run MSD at, by default []
|
||||||
|
If empty, MSD is not used.
|
||||||
|
periods : list, optional
|
||||||
|
periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
|
||||||
|
fft_sizes : list, optional
|
||||||
|
Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
|
||||||
|
sample_rate : int, optional
|
||||||
|
Sampling rate of audio in Hz, by default 44100
|
||||||
|
bands : list, optional
|
||||||
|
Bands to run MRD at, by default `BANDS`
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
discs = []
|
||||||
|
discs += [MPD(p, channels=channels) for p in periods]
|
||||||
|
discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
|
||||||
|
discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
|
||||||
|
self.discriminators = nn.ModuleList(discs)
|
||||||
|
|
||||||
|
def preprocess(self, y):
|
||||||
|
# Remove DC offset
|
||||||
|
y = y - y.mean(dim=-1, keepdims=True)
|
||||||
|
# Peak normalize the volume of input audio
|
||||||
|
y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.preprocess(x)
|
||||||
|
fmaps = [d(x) for d in self.discriminators]
|
||||||
|
return fmaps
|
||||||
|
|
||||||
|
class DACGANLoss(nn.Module):
|
||||||
|
"""
|
||||||
|
Computes a discriminator loss, given a discriminator on
|
||||||
|
generated waveforms/spectrograms compared to ground truth
|
||||||
|
waveforms/spectrograms. Computes the loss for both the
|
||||||
|
discriminator and the generator in separate functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **discriminator_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.discriminator = DACDiscriminator(**discriminator_kwargs)
|
||||||
|
|
||||||
|
def forward(self, fake, real):
|
||||||
|
d_fake = self.discriminator(fake)
|
||||||
|
d_real = self.discriminator(real)
|
||||||
|
return d_fake, d_real
|
||||||
|
|
||||||
|
def discriminator_loss(self, fake, real):
|
||||||
|
d_fake, d_real = self.forward(fake.clone().detach(), real)
|
||||||
|
|
||||||
|
loss_d = 0
|
||||||
|
for x_fake, x_real in zip(d_fake, d_real):
|
||||||
|
loss_d += torch.mean(x_fake[-1] ** 2)
|
||||||
|
loss_d += torch.mean((1 - x_real[-1]) ** 2)
|
||||||
|
return loss_d
|
||||||
|
|
||||||
|
def generator_loss(self, fake, real):
|
||||||
|
d_fake, d_real = self.forward(fake, real)
|
||||||
|
|
||||||
|
loss_g = 0
|
||||||
|
for x_fake in d_fake:
|
||||||
|
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
|
||||||
|
|
||||||
|
loss_feature = 0
|
||||||
|
|
||||||
|
for i in range(len(d_fake)):
|
||||||
|
for j in range(len(d_fake[i]) - 1):
|
||||||
|
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
|
||||||
|
return loss_g, loss_feature
|
||||||
|
|
||||||
|
def loss(self, fake, real):
|
||||||
|
gen_loss, feature_distance = self.generator_loss(fake, real)
|
||||||
|
dis_loss = self.discriminator_loss(fake, real)
|
||||||
|
|
||||||
|
return dis_loss, gen_loss, feature_distance
|
||||||
379
stable_audio_tools/models/dit.py
Normal file
379
stable_audio_tools/models/dit.py
Normal file
|
|
@ -0,0 +1,379 @@
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from x_transformers import ContinuousTransformerWrapper, Encoder
|
||||||
|
|
||||||
|
from .blocks import FourierFeatures
|
||||||
|
from .transformer import ContinuousTransformer
|
||||||
|
|
||||||
|
class DiffusionTransformer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
io_channels=32,
|
||||||
|
patch_size=1,
|
||||||
|
embed_dim=768,
|
||||||
|
cond_token_dim=0,
|
||||||
|
project_cond_tokens=True,
|
||||||
|
global_cond_dim=0,
|
||||||
|
project_global_cond=True,
|
||||||
|
input_concat_dim=0,
|
||||||
|
prepend_cond_dim=0,
|
||||||
|
depth=12,
|
||||||
|
num_heads=8,
|
||||||
|
transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
|
||||||
|
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_token_dim = cond_token_dim
|
||||||
|
|
||||||
|
# Timestep embeddings
|
||||||
|
timestep_features_dim = 256
|
||||||
|
|
||||||
|
self.timestep_features = FourierFeatures(1, timestep_features_dim)
|
||||||
|
|
||||||
|
self.to_timestep_embed = nn.Sequential(
|
||||||
|
nn.Linear(timestep_features_dim, embed_dim, bias=True),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cond_token_dim > 0:
|
||||||
|
# Conditioning tokens
|
||||||
|
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||||
|
self.to_cond_embed = nn.Sequential(
|
||||||
|
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_embed_dim = 0
|
||||||
|
|
||||||
|
if global_cond_dim > 0:
|
||||||
|
# Global conditioning
|
||||||
|
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||||
|
self.to_global_embed = nn.Sequential(
|
||||||
|
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
if prepend_cond_dim > 0:
|
||||||
|
# Prepend conditioning
|
||||||
|
self.to_prepend_embed = nn.Sequential(
|
||||||
|
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_concat_dim = input_concat_dim
|
||||||
|
|
||||||
|
dim_in = io_channels + self.input_concat_dim
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
# Transformer
|
||||||
|
|
||||||
|
self.transformer_type = transformer_type
|
||||||
|
|
||||||
|
self.global_cond_type = global_cond_type
|
||||||
|
|
||||||
|
if self.transformer_type == "x-transformers":
|
||||||
|
self.transformer = ContinuousTransformerWrapper(
|
||||||
|
dim_in=dim_in * patch_size,
|
||||||
|
dim_out=io_channels * patch_size,
|
||||||
|
max_seq_len=0, #Not relevant without absolute positional embeds
|
||||||
|
attn_layers = Encoder(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=num_heads,
|
||||||
|
attn_flash = True,
|
||||||
|
cross_attend = cond_token_dim > 0,
|
||||||
|
dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
|
||||||
|
zero_init_branch_output=True,
|
||||||
|
use_abs_pos_emb = False,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
ff_swish = True,
|
||||||
|
ff_glu = True,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.transformer_type == "continuous_transformer":
|
||||||
|
|
||||||
|
global_dim = None
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
# The global conditioning is projected to the embed_dim already at this point
|
||||||
|
global_dim = embed_dim
|
||||||
|
|
||||||
|
self.transformer = ContinuousTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
dim_heads=embed_dim // num_heads,
|
||||||
|
dim_in=dim_in * patch_size,
|
||||||
|
dim_out=io_channels * patch_size,
|
||||||
|
cross_attend = cond_token_dim > 0,
|
||||||
|
cond_token_dim = cond_embed_dim,
|
||||||
|
global_cond_dim=global_dim,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||||
|
|
||||||
|
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
|
||||||
|
nn.init.zeros_(self.preprocess_conv.weight)
|
||||||
|
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
|
||||||
|
nn.init.zeros_(self.postprocess_conv.weight)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
mask=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_cond_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
return_info=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
cross_attn_cond = self.to_cond_embed(cross_attn_cond) # MLP endecoder, shape: [1, 130, 768]
|
||||||
|
|
||||||
|
if global_embed is not None:
|
||||||
|
# Project the global conditioning to the embedding dimension
|
||||||
|
global_embed = self.to_global_embed(global_embed)
|
||||||
|
|
||||||
|
prepend_inputs = None
|
||||||
|
prepend_mask = None
|
||||||
|
prepend_length = 0
|
||||||
|
if prepend_cond is not None:
|
||||||
|
# Project the prepend conditioning to the embedding dimension
|
||||||
|
prepend_cond = self.to_prepend_embed(prepend_cond)
|
||||||
|
|
||||||
|
prepend_inputs = prepend_cond
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_mask = prepend_cond_mask
|
||||||
|
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
|
||||||
|
# Interpolate input_concat_cond to the same length as x
|
||||||
|
if input_concat_cond.shape[2] != x.shape[2]:
|
||||||
|
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||||
|
|
||||||
|
x = torch.cat([x, input_concat_cond], dim=1)
|
||||||
|
|
||||||
|
# Get the batch of timestep embeddings
|
||||||
|
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
||||||
|
|
||||||
|
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||||
|
if global_embed is not None:
|
||||||
|
global_embed = global_embed + timestep_embed
|
||||||
|
else:
|
||||||
|
global_embed = timestep_embed
|
||||||
|
|
||||||
|
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
||||||
|
if self.global_cond_type == "prepend": # True
|
||||||
|
if prepend_inputs is None: # True
|
||||||
|
# Prepend inputs are just the global embed, and the mask is all ones
|
||||||
|
prepend_inputs = global_embed.unsqueeze(1) # [1, 1, 1536]
|
||||||
|
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
# Prepend inputs are the prepend conditioning + the global embed
|
||||||
|
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
||||||
|
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
||||||
|
|
||||||
|
prepend_length = prepend_inputs.shape[1] # 1
|
||||||
|
|
||||||
|
x = self.preprocess_conv(x) + x # [1, 64, 1024]
|
||||||
|
|
||||||
|
x = rearrange(x, "b c t -> b t c") # [1, 1024, 64]
|
||||||
|
|
||||||
|
extra_args = {}
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN": # 'prepend'
|
||||||
|
extra_args["global_cond"] = global_embed
|
||||||
|
|
||||||
|
if self.patch_size > 1: # self.patch_size==1
|
||||||
|
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
||||||
|
|
||||||
|
if self.transformer_type == "x-transformers":
|
||||||
|
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
|
||||||
|
elif self.transformer_type == "continuous_transformer":
|
||||||
|
|
||||||
|
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
output, info = output
|
||||||
|
elif self.transformer_type == "mm_transformer":
|
||||||
|
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
|
||||||
|
|
||||||
|
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
||||||
|
|
||||||
|
if self.patch_size > 1:
|
||||||
|
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
||||||
|
|
||||||
|
output = self.postprocess_conv(output) + output
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output, info
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_cond_mask=None,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
negative_global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob=0.0,
|
||||||
|
causal=False,
|
||||||
|
scale_phi=0.0,
|
||||||
|
mask=None,
|
||||||
|
return_info=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
|
||||||
|
|
||||||
|
if cross_attn_cond_mask is not None:
|
||||||
|
cross_attn_cond_mask = cross_attn_cond_mask.bool()
|
||||||
|
|
||||||
|
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_cond_mask = prepend_cond_mask.bool()
|
||||||
|
|
||||||
|
# CFG dropout
|
||||||
|
if cfg_dropout_prob > 0.0:
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
|
||||||
|
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
|
||||||
|
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
||||||
|
|
||||||
|
|
||||||
|
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
|
||||||
|
# Classifier-free guidance
|
||||||
|
# Concatenate conditioned and unconditioned inputs on the batch dimension
|
||||||
|
batch_inputs = torch.cat([x, x], dim=0)
|
||||||
|
batch_timestep = torch.cat([t, t], dim=0)
|
||||||
|
|
||||||
|
if global_embed is not None:
|
||||||
|
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
|
||||||
|
else:
|
||||||
|
batch_global_cond = None
|
||||||
|
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
|
||||||
|
else:
|
||||||
|
batch_input_concat_cond = None
|
||||||
|
|
||||||
|
batch_cond = None
|
||||||
|
batch_cond_masks = None
|
||||||
|
|
||||||
|
# Handle CFG for cross-attention conditioning
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
|
||||||
|
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
|
||||||
|
if negative_cross_attn_cond is not None:
|
||||||
|
|
||||||
|
# If there's a negative cross-attention mask, set the masked tokens to the null embed
|
||||||
|
if negative_cross_attn_mask is not None:
|
||||||
|
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
|
||||||
|
|
||||||
|
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
|
||||||
|
|
||||||
|
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if cross_attn_cond_mask is not None:
|
||||||
|
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
|
||||||
|
|
||||||
|
batch_prepend_cond = None
|
||||||
|
batch_prepend_cond_mask = None
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
|
||||||
|
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
batch_masks = torch.cat([mask, mask], dim=0)
|
||||||
|
else:
|
||||||
|
batch_masks = None
|
||||||
|
|
||||||
|
batch_output = self._forward(
|
||||||
|
batch_inputs,
|
||||||
|
batch_timestep,
|
||||||
|
cross_attn_cond=batch_cond,
|
||||||
|
cross_attn_cond_mask=batch_cond_masks,
|
||||||
|
mask = batch_masks,
|
||||||
|
input_concat_cond=batch_input_concat_cond,
|
||||||
|
global_embed = batch_global_cond,
|
||||||
|
prepend_cond = batch_prepend_cond,
|
||||||
|
prepend_cond_mask = batch_prepend_cond_mask,
|
||||||
|
return_info = return_info,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
batch_output, info = batch_output
|
||||||
|
|
||||||
|
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
|
||||||
|
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
|
||||||
|
|
||||||
|
# CFG Rescale
|
||||||
|
if scale_phi != 0.0:
|
||||||
|
cond_out_std = cond_output.std(dim=1, keepdim=True)
|
||||||
|
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
|
||||||
|
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
|
||||||
|
else:
|
||||||
|
output = cfg_output
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output, info
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
else:
|
||||||
|
return self._forward(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=cross_attn_cond,
|
||||||
|
cross_attn_cond_mask=cross_attn_cond_mask,
|
||||||
|
input_concat_cond=input_concat_cond,
|
||||||
|
global_embed=global_embed,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
mask=mask,
|
||||||
|
return_info=return_info,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
153
stable_audio_tools/models/factory.py
Normal file
153
stable_audio_tools/models/factory.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
def create_model_from_config(model_config):
|
||||||
|
model_type = model_config.get('model_type', None)
|
||||||
|
|
||||||
|
assert model_type is not None, 'model_type must be specified in model config'
|
||||||
|
|
||||||
|
if model_type == 'autoencoder':
|
||||||
|
from .autoencoders import create_autoencoder_from_config
|
||||||
|
return create_autoencoder_from_config(model_config)
|
||||||
|
elif model_type == 'diffusion_uncond':
|
||||||
|
from .diffusion import create_diffusion_uncond_from_config
|
||||||
|
return create_diffusion_uncond_from_config(model_config)
|
||||||
|
elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
|
||||||
|
from .diffusion import create_diffusion_cond_from_config
|
||||||
|
return create_diffusion_cond_from_config(model_config)
|
||||||
|
elif model_type == 'diffusion_autoencoder':
|
||||||
|
from .autoencoders import create_diffAE_from_config
|
||||||
|
return create_diffAE_from_config(model_config)
|
||||||
|
elif model_type == 'lm':
|
||||||
|
from .lm import create_audio_lm_from_config
|
||||||
|
return create_audio_lm_from_config(model_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
def create_model_from_config_path(model_config_path):
|
||||||
|
with open(model_config_path) as f:
|
||||||
|
model_config = json.load(f)
|
||||||
|
|
||||||
|
return create_model_from_config(model_config)
|
||||||
|
|
||||||
|
def create_pretransform_from_config(pretransform_config, sample_rate):
|
||||||
|
pretransform_type = pretransform_config.get('type', None)
|
||||||
|
|
||||||
|
assert pretransform_type is not None, 'type must be specified in pretransform config'
|
||||||
|
|
||||||
|
if pretransform_type == 'autoencoder':
|
||||||
|
from .autoencoders import create_autoencoder_from_config
|
||||||
|
from .pretransforms import AutoencoderPretransform
|
||||||
|
|
||||||
|
# Create fake top-level config to pass sample rate to autoencoder constructor
|
||||||
|
# This is a bit of a hack but it keeps us from re-defining the sample rate in the config
|
||||||
|
autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
|
||||||
|
autoencoder = create_autoencoder_from_config(autoencoder_config)
|
||||||
|
|
||||||
|
scale = pretransform_config.get("scale", 1.0)
|
||||||
|
model_half = pretransform_config.get("model_half", False)
|
||||||
|
iterate_batch = pretransform_config.get("iterate_batch", False)
|
||||||
|
chunked = pretransform_config.get("chunked", False)
|
||||||
|
|
||||||
|
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
|
||||||
|
elif pretransform_type == 'wavelet':
|
||||||
|
from .pretransforms import WaveletPretransform
|
||||||
|
|
||||||
|
wavelet_config = pretransform_config["config"]
|
||||||
|
channels = wavelet_config["channels"]
|
||||||
|
levels = wavelet_config["levels"]
|
||||||
|
wavelet = wavelet_config["wavelet"]
|
||||||
|
|
||||||
|
pretransform = WaveletPretransform(channels, levels, wavelet)
|
||||||
|
elif pretransform_type == 'pqmf':
|
||||||
|
from .pretransforms import PQMFPretransform
|
||||||
|
pqmf_config = pretransform_config["config"]
|
||||||
|
pretransform = PQMFPretransform(**pqmf_config)
|
||||||
|
elif pretransform_type == 'dac_pretrained':
|
||||||
|
from .pretransforms import PretrainedDACPretransform
|
||||||
|
pretrained_dac_config = pretransform_config["config"]
|
||||||
|
pretransform = PretrainedDACPretransform(**pretrained_dac_config)
|
||||||
|
elif pretransform_type == "audiocraft_pretrained":
|
||||||
|
from .pretransforms import AudiocraftCompressionPretransform
|
||||||
|
|
||||||
|
audiocraft_config = pretransform_config["config"]
|
||||||
|
pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
|
||||||
|
|
||||||
|
enable_grad = pretransform_config.get('enable_grad', False)
|
||||||
|
pretransform.enable_grad = enable_grad
|
||||||
|
|
||||||
|
pretransform.eval().requires_grad_(pretransform.enable_grad)
|
||||||
|
|
||||||
|
return pretransform
|
||||||
|
|
||||||
|
def create_bottleneck_from_config(bottleneck_config):
|
||||||
|
bottleneck_type = bottleneck_config.get('type', None)
|
||||||
|
|
||||||
|
assert bottleneck_type is not None, 'type must be specified in bottleneck config'
|
||||||
|
|
||||||
|
if bottleneck_type == 'tanh':
|
||||||
|
from .bottleneck import TanhBottleneck
|
||||||
|
bottleneck = TanhBottleneck()
|
||||||
|
elif bottleneck_type == 'vae':
|
||||||
|
from .bottleneck import VAEBottleneck
|
||||||
|
bottleneck = VAEBottleneck()
|
||||||
|
elif bottleneck_type == 'rvq':
|
||||||
|
from .bottleneck import RVQBottleneck
|
||||||
|
|
||||||
|
quantizer_params = {
|
||||||
|
"dim": 128,
|
||||||
|
"codebook_size": 1024,
|
||||||
|
"num_quantizers": 8,
|
||||||
|
"decay": 0.99,
|
||||||
|
"kmeans_init": True,
|
||||||
|
"kmeans_iters": 50,
|
||||||
|
"threshold_ema_dead_code": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
quantizer_params.update(bottleneck_config["config"])
|
||||||
|
|
||||||
|
bottleneck = RVQBottleneck(**quantizer_params)
|
||||||
|
elif bottleneck_type == "dac_rvq":
|
||||||
|
from .bottleneck import DACRVQBottleneck
|
||||||
|
|
||||||
|
bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
|
||||||
|
|
||||||
|
elif bottleneck_type == 'rvq_vae':
|
||||||
|
from .bottleneck import RVQVAEBottleneck
|
||||||
|
|
||||||
|
quantizer_params = {
|
||||||
|
"dim": 128,
|
||||||
|
"codebook_size": 1024,
|
||||||
|
"num_quantizers": 8,
|
||||||
|
"decay": 0.99,
|
||||||
|
"kmeans_init": True,
|
||||||
|
"kmeans_iters": 50,
|
||||||
|
"threshold_ema_dead_code": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
quantizer_params.update(bottleneck_config["config"])
|
||||||
|
|
||||||
|
bottleneck = RVQVAEBottleneck(**quantizer_params)
|
||||||
|
|
||||||
|
elif bottleneck_type == 'dac_rvq_vae':
|
||||||
|
from .bottleneck import DACRVQVAEBottleneck
|
||||||
|
bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
|
||||||
|
elif bottleneck_type == 'l2_norm':
|
||||||
|
from .bottleneck import L2Bottleneck
|
||||||
|
bottleneck = L2Bottleneck()
|
||||||
|
elif bottleneck_type == "wasserstein":
|
||||||
|
from .bottleneck import WassersteinBottleneck
|
||||||
|
bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
|
||||||
|
elif bottleneck_type == "fsq":
|
||||||
|
from .bottleneck import FSQBottleneck
|
||||||
|
bottleneck = FSQBottleneck(**bottleneck_config["config"])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
|
||||||
|
|
||||||
|
requires_grad = bottleneck_config.get('requires_grad', True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in bottleneck.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return bottleneck
|
||||||
542
stable_audio_tools/models/lm.py
Normal file
542
stable_audio_tools/models/lm.py
Normal file
|
|
@ -0,0 +1,542 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from tqdm.auto import trange
|
||||||
|
import typing as tp
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
|
||||||
|
from .factory import create_pretransform_from_config
|
||||||
|
from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
|
||||||
|
from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
|
||||||
|
from .utils import multinomial, sample_top_k, sample_top_p
|
||||||
|
|
||||||
|
from .codebook_patterns import (
|
||||||
|
CodebooksPatternProvider,
|
||||||
|
DelayedPatternProvider,
|
||||||
|
MusicLMPattern,
|
||||||
|
ParallelPatternProvider,
|
||||||
|
UnrolledPatternProvider
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
|
||||||
|
# License can be found in LICENSES/LICENSE_META.txt
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LMOutput:
|
||||||
|
# The logits are already re-aligned with the input codes
|
||||||
|
# hence no extra shift is required, e.g. when computing CE
|
||||||
|
logits: torch.Tensor # [B, K, T, card]
|
||||||
|
mask: torch.Tensor # [B, K, T]
|
||||||
|
|
||||||
|
# Wrapper for a multi-codebook language model
|
||||||
|
# Handles patterns and quantizer heads
|
||||||
|
class AudioLanguageModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pattern_provider: CodebooksPatternProvider,
|
||||||
|
backbone: AudioLMBackbone,
|
||||||
|
num_quantizers: int,
|
||||||
|
codebook_size: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pattern_provider = pattern_provider
|
||||||
|
self.backbone = backbone
|
||||||
|
self.num_quantizers = num_quantizers
|
||||||
|
self.codebook_size = codebook_size
|
||||||
|
|
||||||
|
self.masked_token_id = codebook_size
|
||||||
|
|
||||||
|
# Per-quantizer embedders
|
||||||
|
# Add one for the mask embed
|
||||||
|
self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
|
||||||
|
|
||||||
|
# Per-quantizer output heads
|
||||||
|
self.quantizer_heads = nn.ModuleList([
|
||||||
|
nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
sequence: torch.Tensor, #[batch, seq_len,
|
||||||
|
prepend_cond=None, #[batch, seq, channels]
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cross_attn_cond=None, #[batch, seq, channels],
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
|
||||||
|
batch, num_quantizers, seq_len = sequence.shape
|
||||||
|
|
||||||
|
assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
|
||||||
|
|
||||||
|
backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
|
||||||
|
|
||||||
|
dtype = next(self.parameters()).dtype
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
cross_attn_cond = cross_attn_cond.to(dtype)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
prepend_cond = prepend_cond.to(dtype)
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_cond_mask = prepend_cond_mask.to(dtype)
|
||||||
|
|
||||||
|
backbone_input = backbone_input.to(dtype)
|
||||||
|
|
||||||
|
output = self.backbone(
|
||||||
|
backbone_input,
|
||||||
|
cross_attn_cond=cross_attn_cond,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
**kwargs
|
||||||
|
) # [batch, seq_len, embed_dim]
|
||||||
|
|
||||||
|
# Run output through quantizer heads
|
||||||
|
logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
codes, #[batch, num_quantizers, seq_len]
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
|
||||||
|
Handles translation between input sequence and pattern-shifted sequence
|
||||||
|
Only used during training
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch, _, seq_len = codes.shape
|
||||||
|
|
||||||
|
pattern = self.pattern_provider.get_pattern(seq_len)
|
||||||
|
|
||||||
|
# Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
|
||||||
|
shifted_codes, _, _ = pattern.build_pattern_sequence(
|
||||||
|
codes,
|
||||||
|
self.masked_token_id,
|
||||||
|
keep_only_valid_steps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
|
||||||
|
logits = self(shifted_codes, **kwargs)
|
||||||
|
|
||||||
|
# Rearrange logits to prepare to revert pattern
|
||||||
|
logits = rearrange(logits, "b n s c -> b c n s")
|
||||||
|
|
||||||
|
# Revert sequence logits back to original sequence length, removing masked steps
|
||||||
|
logits, _, logits_mask = pattern.revert_pattern_logits(
|
||||||
|
logits, float('nan'), keep_only_valid_steps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = rearrange(logits, "b c n t -> b n t c")
|
||||||
|
|
||||||
|
logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
|
||||||
|
|
||||||
|
return LMOutput(logits=logits, mask=logits_mask)
|
||||||
|
|
||||||
|
# Conditioning and generation wrapper for a multi-codebook language model
|
||||||
|
# Handles conditioning, CFG, generation, and encoding/decoding
|
||||||
|
class AudioLanguageModelWrapper(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pretransform: Pretransform,
|
||||||
|
lm: AudioLanguageModel,
|
||||||
|
sample_rate: int,
|
||||||
|
min_input_length: int,
|
||||||
|
conditioner: MultiConditioner = None,
|
||||||
|
cross_attn_cond_ids: tp.List[str] = [],
|
||||||
|
prepend_cond_ids: tp.List[str] = [],
|
||||||
|
global_cond_ids: tp.List[str] = []
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert pretransform.is_discrete, "Pretransform must be discrete"
|
||||||
|
self.pretransform = pretransform
|
||||||
|
|
||||||
|
self.pretransform.requires_grad_(False)
|
||||||
|
self.pretransform.eval()
|
||||||
|
|
||||||
|
if isinstance(self.pretransform, AutoencoderPretransform):
|
||||||
|
self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
|
||||||
|
self.codebook_size = self.pretransform.model.bottleneck.codebook_size
|
||||||
|
elif isinstance(self.pretransform, PretrainedDACPretransform):
|
||||||
|
self.num_quantizers = self.pretransform.model.num_quantizers
|
||||||
|
self.codebook_size = self.pretransform.model.codebook_size
|
||||||
|
elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
|
||||||
|
self.num_quantizers = self.pretransform.num_quantizers
|
||||||
|
self.codebook_size = self.pretransform.codebook_size
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
|
||||||
|
|
||||||
|
self.conditioner = conditioner
|
||||||
|
|
||||||
|
self.lm = lm
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
|
||||||
|
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||||
|
self.prepend_cond_ids = prepend_cond_ids
|
||||||
|
self.global_cond_ids = global_cond_ids
|
||||||
|
|
||||||
|
def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
|
||||||
|
cross_attention_input = None
|
||||||
|
prepend_cond = None
|
||||||
|
prepend_cond_mask = None
|
||||||
|
global_cond = None
|
||||||
|
|
||||||
|
if len(self.cross_attn_cond_ids) > 0:
|
||||||
|
# Concatenate all cross-attention inputs over the sequence dimension
|
||||||
|
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||||
|
cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
|
||||||
|
|
||||||
|
if len(self.prepend_cond_ids) > 0:
|
||||||
|
# Concatenate all prepend conditioning inputs over the sequence dimension
|
||||||
|
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
|
||||||
|
prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
|
||||||
|
prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
|
||||||
|
|
||||||
|
if len(self.global_cond_ids) > 0:
|
||||||
|
# Concatenate all global conditioning inputs over the channel dimension
|
||||||
|
# Assumes that the global conditioning inputs are of shape (batch, channels)
|
||||||
|
global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
|
||||||
|
if len(global_cond.shape) == 3:
|
||||||
|
global_cond = global_cond.squeeze(1)
|
||||||
|
|
||||||
|
if negative:
|
||||||
|
return {
|
||||||
|
"negative_cross_attn_cond": cross_attention_input,
|
||||||
|
"negative_prepend_cond": prepend_cond,
|
||||||
|
"negative_prepend_cond_mask": prepend_cond_mask,
|
||||||
|
"negative_global_cond": global_cond
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"cross_attn_cond": cross_attention_input,
|
||||||
|
"prepend_cond": prepend_cond,
|
||||||
|
"prepend_cond_mask": prepend_cond_mask,
|
||||||
|
"global_cond": global_cond
|
||||||
|
}
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
codes,
|
||||||
|
condition_tensors=None,
|
||||||
|
cfg_dropout_prob=0.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute logits for a batch of codes, and translates from conditioning inputs to model inputs
|
||||||
|
Handles CFG dropout
|
||||||
|
"""
|
||||||
|
|
||||||
|
if condition_tensors is None:
|
||||||
|
condition_tensors = {}
|
||||||
|
|
||||||
|
conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
|
||||||
|
|
||||||
|
cross_attn_cond = conditioning_inputs["cross_attn_cond"]
|
||||||
|
prepend_cond = conditioning_inputs["prepend_cond"]
|
||||||
|
prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
|
||||||
|
global_cond = conditioning_inputs["global_cond"]
|
||||||
|
|
||||||
|
if cfg_dropout_prob > 0.0:
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
|
||||||
|
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
|
||||||
|
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
||||||
|
|
||||||
|
if global_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(global_cond, device=global_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
|
||||||
|
global_cond = torch.where(dropout_mask, null_embed, global_cond)
|
||||||
|
|
||||||
|
return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
|
||||||
|
|
||||||
|
def _sample_next_token(
|
||||||
|
self,
|
||||||
|
sequence, #[batch, num_quantizers, seq_len]
|
||||||
|
conditioning_tensors=None,
|
||||||
|
cross_attn_use_cfg=True,
|
||||||
|
prepend_use_cfg=True,
|
||||||
|
global_use_cfg=True,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
top_k=250,
|
||||||
|
top_p=0.0,
|
||||||
|
temp=1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs
|
||||||
|
Handles CFG inference
|
||||||
|
"""
|
||||||
|
|
||||||
|
if conditioning_tensors is None:
|
||||||
|
conditioning_tensors = {}
|
||||||
|
|
||||||
|
conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
|
||||||
|
|
||||||
|
cross_attn_cond = conditioning_inputs["cross_attn_cond"]
|
||||||
|
prepend_cond = conditioning_inputs["prepend_cond"]
|
||||||
|
prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
|
||||||
|
global_cond = conditioning_inputs["global_cond"]
|
||||||
|
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
|
||||||
|
# Batch size is doubled to account for negative samples
|
||||||
|
sequence = torch.cat([sequence, sequence], dim=0)
|
||||||
|
|
||||||
|
if cross_attn_cond is not None and cross_attn_use_cfg:
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
|
||||||
|
cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if prepend_cond is not None and prepend_use_cfg:
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
|
||||||
|
prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
|
||||||
|
|
||||||
|
if global_cond is not None and global_use_cfg:
|
||||||
|
null_embed = torch.zeros_like(global_cond, device=global_cond.device)
|
||||||
|
|
||||||
|
global_cond = torch.cat([global_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
|
||||||
|
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
cond_logits, uncond_logits = logits.chunk(2, dim=0)
|
||||||
|
|
||||||
|
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
||||||
|
|
||||||
|
logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
|
||||||
|
|
||||||
|
# Grab the logits for the last step
|
||||||
|
logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
|
||||||
|
|
||||||
|
# Apply top-k or top-p sampling
|
||||||
|
|
||||||
|
if temp > 0:
|
||||||
|
probs = torch.softmax(logits / temp, dim=-1)
|
||||||
|
|
||||||
|
if top_p > 0.0:
|
||||||
|
next_token = sample_top_p(probs, p=top_p)
|
||||||
|
elif top_k > 0:
|
||||||
|
next_token = sample_top_k(probs, k=top_k)
|
||||||
|
else:
|
||||||
|
next_token = multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
|
||||||
|
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
max_gen_len: int = 256,
|
||||||
|
batch_size: tp.Optional[int] = None,
|
||||||
|
init_data: tp.Optional[torch.Tensor] = None,
|
||||||
|
conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
|
||||||
|
conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
|
||||||
|
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
||||||
|
use_cache: bool = True,
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
|
||||||
|
if conditioning_tensors is None and conditioning is not None:
|
||||||
|
# Convert conditioning inputs to conditioning tensors
|
||||||
|
conditioning_tensors = self.conditioner(conditioning, device)
|
||||||
|
|
||||||
|
# Check that batch size is consistent across inputs
|
||||||
|
possible_batch_sizes = []
|
||||||
|
|
||||||
|
if batch_size is not None:
|
||||||
|
possible_batch_sizes.append(batch_size)
|
||||||
|
elif init_data is not None:
|
||||||
|
possible_batch_sizes.append(init_data.shape[0])
|
||||||
|
elif conditioning_tensors is not None:
|
||||||
|
# Assume that the first conditioning tensor has the batch dimension
|
||||||
|
possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
|
||||||
|
else:
|
||||||
|
possible_batch_sizes.append(1)
|
||||||
|
|
||||||
|
assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs"
|
||||||
|
|
||||||
|
batch_size = possible_batch_sizes[0]
|
||||||
|
|
||||||
|
if init_data is None:
|
||||||
|
# Initialize with zeros
|
||||||
|
assert batch_size > 0
|
||||||
|
init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
batch_size, num_quantizers, seq_len = init_data.shape
|
||||||
|
|
||||||
|
start_offset = seq_len
|
||||||
|
assert start_offset < max_gen_len, "init data longer than max gen length"
|
||||||
|
|
||||||
|
pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
|
||||||
|
|
||||||
|
unknown_token = -1
|
||||||
|
|
||||||
|
# Initialize the generated codes with the init data, padded with unknown tokens
|
||||||
|
gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
|
||||||
|
gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
|
||||||
|
|
||||||
|
gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
|
||||||
|
|
||||||
|
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
||||||
|
assert start_offset_sequence is not None
|
||||||
|
|
||||||
|
# Generation
|
||||||
|
prev_offset = 0
|
||||||
|
gen_sequence_len = gen_sequence.shape[-1]
|
||||||
|
|
||||||
|
# Reset generation cache
|
||||||
|
if use_cache and self.lm.backbone.use_generation_cache:
|
||||||
|
self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
|
||||||
|
|
||||||
|
for offset in trange(start_offset_sequence, gen_sequence_len):
|
||||||
|
|
||||||
|
# Get the full sequence up to the current offset
|
||||||
|
curr_sequence = gen_sequence[..., prev_offset:offset]
|
||||||
|
|
||||||
|
next_token = self._sample_next_token(
|
||||||
|
curr_sequence,
|
||||||
|
conditioning_tensors=conditioning_tensors,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
|
||||||
|
next_token[~valid_mask] = self.lm.masked_token_id
|
||||||
|
|
||||||
|
# Update the generated sequence with the next token
|
||||||
|
gen_sequence[..., offset:offset+1] = torch.where(
|
||||||
|
gen_sequence[..., offset:offset+1] == unknown_token,
|
||||||
|
next_token,
|
||||||
|
gen_sequence[..., offset:offset+1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache and self.lm.backbone.use_generation_cache:
|
||||||
|
# Only update the offset if caching is being used
|
||||||
|
prev_offset = offset
|
||||||
|
|
||||||
|
self.lm.backbone.update_generation_cache(offset)
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
# Callback to report progress
|
||||||
|
# Pass in the offset relative to the start of the sequence, and the length of the current sequence
|
||||||
|
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
|
||||||
|
|
||||||
|
assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
|
||||||
|
|
||||||
|
out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
||||||
|
|
||||||
|
# sanity checks over the returned codes and corresponding masks
|
||||||
|
assert (out_codes[..., :max_gen_len] != unknown_token).all()
|
||||||
|
assert (out_mask[..., :max_gen_len] == 1).all()
|
||||||
|
|
||||||
|
#out_codes = out_codes[..., 0:max_gen_len]
|
||||||
|
|
||||||
|
return out_codes
|
||||||
|
|
||||||
|
|
||||||
|
def generate_audio(
|
||||||
|
self,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio from a batch of codes
|
||||||
|
"""
|
||||||
|
|
||||||
|
codes = self.generate(**kwargs)
|
||||||
|
|
||||||
|
audio = self.pretransform.decode_tokens(codes)
|
||||||
|
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_lm_from_config(config):
|
||||||
|
model_config = config.get('model', None)
|
||||||
|
assert model_config is not None, 'model config must be specified in config'
|
||||||
|
|
||||||
|
sample_rate = config.get('sample_rate', None)
|
||||||
|
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||||
|
|
||||||
|
lm_config = model_config.get('lm', None)
|
||||||
|
assert lm_config is not None, 'lm config must be specified in model config'
|
||||||
|
|
||||||
|
codebook_pattern = lm_config.get("codebook_pattern", "delay")
|
||||||
|
|
||||||
|
pattern_providers = {
|
||||||
|
'parallel': ParallelPatternProvider,
|
||||||
|
'delay': DelayedPatternProvider,
|
||||||
|
'unroll': UnrolledPatternProvider,
|
||||||
|
'musiclm': MusicLMPattern,
|
||||||
|
}
|
||||||
|
|
||||||
|
pretransform_config = model_config.get("pretransform", None)
|
||||||
|
|
||||||
|
pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
|
||||||
|
|
||||||
|
assert pretransform.is_discrete, "Pretransform must be discrete"
|
||||||
|
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
|
||||||
|
pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
|
||||||
|
|
||||||
|
conditioning_config = model_config.get('conditioning', None)
|
||||||
|
|
||||||
|
conditioner = None
|
||||||
|
if conditioning_config is not None:
|
||||||
|
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
|
||||||
|
|
||||||
|
cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
|
||||||
|
prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
|
||||||
|
global_cond_ids = lm_config.get('global_cond_ids', [])
|
||||||
|
|
||||||
|
lm_type = lm_config.get("type", None)
|
||||||
|
lm_model_config = lm_config.get("config", None)
|
||||||
|
|
||||||
|
assert lm_type is not None, "Must specify lm type in lm config"
|
||||||
|
assert lm_model_config is not None, "Must specify lm model config in lm config"
|
||||||
|
|
||||||
|
if lm_type == "x-transformers":
|
||||||
|
backbone = XTransformersAudioLMBackbone(**lm_model_config)
|
||||||
|
elif lm_type == "continuous_transformer":
|
||||||
|
backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unrecognized lm type {lm_type}")
|
||||||
|
|
||||||
|
lm = AudioLanguageModel(
|
||||||
|
pattern_provider=pattern_provider,
|
||||||
|
backbone=backbone,
|
||||||
|
num_quantizers=pretransform.num_quantizers,
|
||||||
|
codebook_size=pretransform.codebook_size
|
||||||
|
)
|
||||||
|
|
||||||
|
model = AudioLanguageModelWrapper(
|
||||||
|
pretransform=pretransform,
|
||||||
|
lm=lm,
|
||||||
|
conditioner=conditioner,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
min_input_length=min_input_length,
|
||||||
|
cross_attn_cond_ids=cross_attn_cond_ids,
|
||||||
|
prepend_cond_ids=prepend_cond_ids,
|
||||||
|
global_cond_ids=global_cond_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
278
stable_audio_tools/models/local_attention.py
Normal file
278
stable_audio_tools/models/local_attention.py
Normal file
|
|
@ -0,0 +1,278 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .blocks import AdaRMSNorm
|
||||||
|
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
||||||
|
|
||||||
|
def checkpoint(function, *args, **kwargs):
|
||||||
|
kwargs.setdefault("use_reentrant", False)
|
||||||
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||||
|
|
||||||
|
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||||
|
class ContinuousLocalTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
dim_in = None,
|
||||||
|
dim_out = None,
|
||||||
|
causal = False,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
heads = 8,
|
||||||
|
ff_mult = 2,
|
||||||
|
cond_dim = 0,
|
||||||
|
cross_attn_cond_dim = 0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim_head = dim//heads
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.local_attn_window_size = local_attn_window_size
|
||||||
|
|
||||||
|
self.cond_dim = cond_dim
|
||||||
|
|
||||||
|
self.cross_attn_cond_dim = cross_attn_cond_dim
|
||||||
|
|
||||||
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
|
||||||
|
|
||||||
|
for _ in range(depth):
|
||||||
|
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||||
|
Attention(
|
||||||
|
dim=dim,
|
||||||
|
dim_heads=dim_head,
|
||||||
|
causal=causal,
|
||||||
|
zero_init_output=True,
|
||||||
|
natten_kernel_size=local_attn_window_size,
|
||||||
|
),
|
||||||
|
Attention(
|
||||||
|
dim=dim,
|
||||||
|
dim_heads=dim_head,
|
||||||
|
dim_context = cross_attn_cond_dim,
|
||||||
|
zero_init_output=True
|
||||||
|
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
|
||||||
|
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||||
|
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
|
||||||
|
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
x = torch.cat([prepend_cond, x], dim=1)
|
||||||
|
|
||||||
|
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||||
|
|
||||||
|
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if cond is not None:
|
||||||
|
x = checkpoint(attn_norm, x, cond)
|
||||||
|
else:
|
||||||
|
x = checkpoint(attn_norm, x)
|
||||||
|
|
||||||
|
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if cond is not None:
|
||||||
|
x = checkpoint(ff_norm, x, cond)
|
||||||
|
else:
|
||||||
|
x = checkpoint(ff_norm, x)
|
||||||
|
|
||||||
|
x = checkpoint(ff, x) + residual
|
||||||
|
|
||||||
|
return checkpoint(self.project_out, x)
|
||||||
|
|
||||||
|
class TransformerDownsampleBlock1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
embed_dim = 768,
|
||||||
|
depth = 3,
|
||||||
|
heads = 12,
|
||||||
|
downsample_ratio = 2,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
|
||||||
|
self.transformer = ContinuousLocalTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=heads,
|
||||||
|
local_attn_window_size=local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
# Compute
|
||||||
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
# Trade sequence length for channels
|
||||||
|
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
|
||||||
|
|
||||||
|
# Project back to embed dim
|
||||||
|
x = checkpoint(self.project_down, x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TransformerUpsampleBlock1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
depth = 3,
|
||||||
|
heads = 12,
|
||||||
|
upsample_ratio = 2,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.upsample_ratio = upsample_ratio
|
||||||
|
|
||||||
|
self.transformer = ContinuousLocalTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=heads,
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
# Project to embed dim
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
# Project to increase channel dim
|
||||||
|
x = checkpoint(self.project_up, x)
|
||||||
|
|
||||||
|
# Trade channels for sequence length
|
||||||
|
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
|
||||||
|
|
||||||
|
# Compute
|
||||||
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
embed_dims = [96, 192, 384, 768],
|
||||||
|
heads = [12, 12, 12, 12],
|
||||||
|
depths = [3, 3, 3, 3],
|
||||||
|
ratios = [2, 2, 2, 2],
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for layer in range(len(depths)):
|
||||||
|
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
TransformerDownsampleBlock1D(
|
||||||
|
in_channels = prev_dim,
|
||||||
|
embed_dim = embed_dims[layer],
|
||||||
|
heads = heads[layer],
|
||||||
|
depth = depths[layer],
|
||||||
|
downsample_ratio = ratios[layer],
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||||
|
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
x = checkpoint(self.project_out, x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
embed_dims = [768, 384, 192, 96],
|
||||||
|
heads = [12, 12, 12, 12],
|
||||||
|
depths = [3, 3, 3, 3],
|
||||||
|
ratios = [2, 2, 2, 2],
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for layer in range(len(depths)):
|
||||||
|
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
TransformerUpsampleBlock1D(
|
||||||
|
in_channels = prev_dim,
|
||||||
|
embed_dim = embed_dims[layer],
|
||||||
|
heads = heads[layer],
|
||||||
|
depth = depths[layer],
|
||||||
|
upsample_ratio = ratios[layer],
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||||
|
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
x = checkpoint(self.project_out, x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
return x
|
||||||
393
stable_audio_tools/models/pqmf.py
Normal file
393
stable_audio_tools/models/pqmf.py
Normal file
|
|
@ -0,0 +1,393 @@
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from scipy.optimize import fmin
|
||||||
|
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
|
||||||
|
|
||||||
|
class PQMF(nn.Module):
|
||||||
|
"""
|
||||||
|
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
|
||||||
|
Uses polyphase representation which is computationally more efficient for real-time.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
|
||||||
|
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, attenuation, num_bands):
|
||||||
|
super(PQMF, self).__init__()
|
||||||
|
|
||||||
|
# Ensure num_bands is a power of 2
|
||||||
|
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
|
||||||
|
assert is_power_of_2, "'num_bands' must be a power of 2."
|
||||||
|
|
||||||
|
# Create the prototype filter
|
||||||
|
prototype_filter = design_prototype_filter(attenuation, num_bands)
|
||||||
|
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
|
||||||
|
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
|
||||||
|
|
||||||
|
# Register filters and settings
|
||||||
|
self.register_buffer("filter_bank", padded_filter_bank)
|
||||||
|
self.register_buffer("prototype", prototype_filter)
|
||||||
|
self.num_bands = num_bands
|
||||||
|
|
||||||
|
def forward(self, signal):
|
||||||
|
"""Decompose the signal into multiple frequency bands."""
|
||||||
|
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
|
||||||
|
signal = prepare_signal_dimensions(signal)
|
||||||
|
# The signal length must be a multiple of num_bands. Pad it with zeros.
|
||||||
|
signal = pad_signal(signal, self.num_bands)
|
||||||
|
# run it
|
||||||
|
signal = polyphase_analysis(signal, self.filter_bank)
|
||||||
|
return apply_alias_cancellation(signal)
|
||||||
|
|
||||||
|
def inverse(self, bands):
|
||||||
|
"""Reconstruct the original signal from the frequency bands."""
|
||||||
|
bands = apply_alias_cancellation(bands)
|
||||||
|
return polyphase_synthesis(bands, self.filter_bank)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_signal_dimensions(signal):
|
||||||
|
"""
|
||||||
|
Rearrange signal into Batch x Channels x Length.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor or numpy.ndarray
|
||||||
|
The input signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Preprocessed signal tensor.
|
||||||
|
"""
|
||||||
|
# Convert numpy to torch tensor
|
||||||
|
if isinstance(signal, np.ndarray):
|
||||||
|
signal = torch.from_numpy(signal)
|
||||||
|
|
||||||
|
# Ensure tensor
|
||||||
|
if not isinstance(signal, torch.Tensor):
|
||||||
|
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
|
||||||
|
|
||||||
|
# Modify dimension of signal to Batch x Channels x Length
|
||||||
|
if signal.dim() == 1:
|
||||||
|
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
|
||||||
|
signal = signal.unsqueeze(0).unsqueeze(0)
|
||||||
|
elif signal.dim() == 2:
|
||||||
|
# This is a multi-channel signal (e.g. stereo)
|
||||||
|
# Rearrange so that larger dimension (Length) is last
|
||||||
|
if signal.shape[0] > signal.shape[1]:
|
||||||
|
signal = signal.T
|
||||||
|
# Unsqueeze to 1 x Channels x Length
|
||||||
|
signal = signal.unsqueeze(0)
|
||||||
|
return signal
|
||||||
|
|
||||||
|
def pad_signal(signal, num_bands):
|
||||||
|
"""
|
||||||
|
Pads the signal to make its length divisible by the given number of bands.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
The input signal tensor, where the last dimension represents the signal length.
|
||||||
|
|
||||||
|
num_bands : int
|
||||||
|
The number of bands by which the signal length should be divisible.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
The padded signal tensor. If the original signal length was already divisible
|
||||||
|
by num_bands, returns the original signal unchanged.
|
||||||
|
"""
|
||||||
|
remainder = signal.shape[-1] % num_bands
|
||||||
|
if remainder > 0:
|
||||||
|
padding_size = num_bands - remainder
|
||||||
|
signal = nn.functional.pad(signal, (0, padding_size))
|
||||||
|
return signal
|
||||||
|
|
||||||
|
def generate_modulated_filter_bank(prototype_filter, num_bands):
|
||||||
|
"""
|
||||||
|
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prototype_filter : torch.Tensor
|
||||||
|
The prototype filter used as the basis for modulation.
|
||||||
|
num_bands : int
|
||||||
|
The number of desired subbands or filters.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
A bank of cosine modulated filters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize indices for modulation.
|
||||||
|
subband_indices = torch.arange(num_bands).reshape(-1, 1)
|
||||||
|
|
||||||
|
# Calculate the length of the prototype filter.
|
||||||
|
filter_length = prototype_filter.shape[-1]
|
||||||
|
|
||||||
|
# Generate symmetric time indices centered around zero.
|
||||||
|
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
|
||||||
|
|
||||||
|
# Calculate phase offsets to ensure orthogonality between subbands.
|
||||||
|
phase_offsets = (-1)**subband_indices * np.pi / 4
|
||||||
|
|
||||||
|
# Compute the cosine modulation function.
|
||||||
|
modulation = torch.cos(
|
||||||
|
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply modulation to the prototype filter.
|
||||||
|
modulated_filters = 2 * prototype_filter * modulation
|
||||||
|
|
||||||
|
return modulated_filters
|
||||||
|
|
||||||
|
|
||||||
|
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
|
||||||
|
"""
|
||||||
|
Design a lowpass filter using the Kaiser window.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
angular_cutoff : float
|
||||||
|
The angular frequency cutoff of the filter.
|
||||||
|
attenuation : float
|
||||||
|
The desired stopband attenuation in decibels (dB).
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ndarray
|
||||||
|
The designed lowpass filter coefficients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
|
||||||
|
|
||||||
|
# Ensure the estimated length is odd.
|
||||||
|
estimated_length = 2 * (estimated_length // 2) + 1
|
||||||
|
|
||||||
|
if filter_length is None:
|
||||||
|
filter_length = estimated_length
|
||||||
|
|
||||||
|
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
|
||||||
|
"""
|
||||||
|
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
angular_cutoff : float
|
||||||
|
Angular frequency cutoff of the filter.
|
||||||
|
attenuation : float
|
||||||
|
Desired stopband attenuation in dB.
|
||||||
|
num_bands : int
|
||||||
|
Number of bands for the multiband filter system.
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The computed objective (loss) value for the given filter specs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
|
||||||
|
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
|
||||||
|
|
||||||
|
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def design_prototype_filter(attenuation, num_bands, filter_length=None):
|
||||||
|
"""
|
||||||
|
Design the optimal prototype filter for a multiband system given the desired specs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
attenuation : float
|
||||||
|
The desired stopband attenuation in dB.
|
||||||
|
num_bands : int
|
||||||
|
Number of bands for the multiband filter system.
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ndarray
|
||||||
|
The optimal prototype filter coefficients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
|
||||||
|
1 / num_bands, disp=0)[0]
|
||||||
|
|
||||||
|
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
|
||||||
|
return torch.tensor(prototype_filter, dtype=torch.float32)
|
||||||
|
|
||||||
|
def pad_to_nearest_power_of_two(x):
|
||||||
|
"""
|
||||||
|
Pads the input tensor 'x' on both sides such that its last dimension
|
||||||
|
becomes the nearest larger power of two.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
x : torch.Tensor
|
||||||
|
The input tensor to be padded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
The padded tensor.
|
||||||
|
"""
|
||||||
|
current_length = x.shape[-1]
|
||||||
|
target_length = 2**math.ceil(math.log2(current_length))
|
||||||
|
|
||||||
|
total_padding = target_length - current_length
|
||||||
|
left_padding = total_padding // 2
|
||||||
|
right_padding = total_padding - left_padding
|
||||||
|
|
||||||
|
return nn.functional.pad(x, (left_padding, right_padding))
|
||||||
|
|
||||||
|
def apply_alias_cancellation(x):
|
||||||
|
"""
|
||||||
|
Applies alias cancellation by inverting the sign of every
|
||||||
|
second element of every second row, starting from the second
|
||||||
|
row's first element in a tensor.
|
||||||
|
|
||||||
|
This operation helps ensure that the aliasing introduced in
|
||||||
|
each band during the decomposition will be counteracted during
|
||||||
|
the reconstruction.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
x : torch.Tensor
|
||||||
|
The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
Tensor with specific elements' sign inverted for alias cancellation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a mask of the same shape as 'x', initialized with all ones
|
||||||
|
mask = torch.ones_like(x)
|
||||||
|
|
||||||
|
# Update specific elements in the mask to -1 to perform inversion
|
||||||
|
mask[..., 1::2, ::2] = -1
|
||||||
|
|
||||||
|
# Apply the mask to the input tensor 'x'
|
||||||
|
return x * mask
|
||||||
|
|
||||||
|
def ensure_odd_length(tensor):
|
||||||
|
"""
|
||||||
|
Pads the last dimension of a tensor to ensure its size is odd.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Input tensor whose last dimension might need padding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
The original tensor if its last dimension was already odd,
|
||||||
|
or the padded tensor with an odd-sized last dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_dim_size = tensor.shape[-1]
|
||||||
|
|
||||||
|
if last_dim_size % 2 == 0:
|
||||||
|
tensor = nn.functional.pad(tensor, (0, 1))
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def polyphase_analysis(signal, filter_bank):
|
||||||
|
"""
|
||||||
|
Applies the polyphase method to efficiently analyze the signal using a filter bank.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
Input signal tensor with shape (Batch x Channels x Length).
|
||||||
|
|
||||||
|
filter_bank : torch.Tensor
|
||||||
|
Filter bank tensor with shape (Bands x Length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
Signal split into sub-bands. (Batch x Channels x Bands x Length)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_bands = filter_bank.shape[0]
|
||||||
|
num_channels = signal.shape[1]
|
||||||
|
|
||||||
|
# Rearrange signal for polyphase processing.
|
||||||
|
# Also combine Batch x Channel into one dimension for now.
|
||||||
|
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
|
||||||
|
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
|
||||||
|
|
||||||
|
# Rearrange the filter bank for matching signal shape
|
||||||
|
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
|
||||||
|
|
||||||
|
# Apply convolution with appropriate padding to maintain spatial dimensions
|
||||||
|
padding = filter_bank.shape[-1] // 2
|
||||||
|
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
|
||||||
|
|
||||||
|
# Truncate the last dimension post-convolution to adjust the output shape
|
||||||
|
filtered_signal = filtered_signal[..., :-1]
|
||||||
|
# Rearrange the first dimension back into Batch x Channels
|
||||||
|
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
|
||||||
|
|
||||||
|
return filtered_signal
|
||||||
|
|
||||||
|
def polyphase_synthesis(signal, filter_bank):
|
||||||
|
"""
|
||||||
|
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
|
||||||
|
|
||||||
|
filter_bank : torch.Tensor
|
||||||
|
Analysis filter bank (shape: Bands x Length).
|
||||||
|
|
||||||
|
should_rearrange : bool, optional
|
||||||
|
Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Reconstructed signal (shape: Batch x Channels X Length)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_bands = filter_bank.shape[0]
|
||||||
|
num_channels = signal.shape[1]
|
||||||
|
|
||||||
|
# Rearrange the filter bank
|
||||||
|
filter_bank = filter_bank.flip(-1)
|
||||||
|
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
|
||||||
|
|
||||||
|
# Combine Batch x Channels into one dimension for now.
|
||||||
|
signal = rearrange(signal, "b c n t -> (b c) n t")
|
||||||
|
|
||||||
|
# Apply convolution with appropriate padding
|
||||||
|
padding_amount = filter_bank.shape[-1] // 2 + 1
|
||||||
|
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
|
||||||
|
|
||||||
|
# Scale the result
|
||||||
|
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
|
||||||
|
|
||||||
|
# Reorganize the output and truncate
|
||||||
|
reconstructed_signal = reconstructed_signal.flip(1)
|
||||||
|
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
|
||||||
|
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
|
||||||
|
|
||||||
|
return reconstructed_signal
|
||||||
25
stable_audio_tools/models/pretrained.py
Normal file
25
stable_audio_tools/models/pretrained.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .factory import create_model_from_config
|
||||||
|
from .utils import load_ckpt_state_dict
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
def get_pretrained_model(name: str):
|
||||||
|
|
||||||
|
model_config_path = hf_hub_download(name, filename="config.json", repo_type='model')
|
||||||
|
|
||||||
|
with open(model_config_path) as f:
|
||||||
|
model_config = json.load(f)
|
||||||
|
|
||||||
|
model = create_model_from_config(model_config)
|
||||||
|
|
||||||
|
# Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
|
||||||
|
try:
|
||||||
|
model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
|
||||||
|
except Exception as e:
|
||||||
|
model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
|
||||||
|
|
||||||
|
model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
|
||||||
|
|
||||||
|
return model, model_config
|
||||||
258
stable_audio_tools/models/pretransforms.py
Normal file
258
stable_audio_tools/models/pretransforms.py
Normal file
|
|
@ -0,0 +1,258 @@
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class Pretransform(nn.Module):
|
||||||
|
def __init__(self, enable_grad, io_channels, is_discrete):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_discrete = is_discrete
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.encoded_channels = None
|
||||||
|
self.downsampling_ratio = None
|
||||||
|
|
||||||
|
self.enable_grad = enable_grad
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class AutoencoderPretransform(Pretransform):
|
||||||
|
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
|
||||||
|
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
|
||||||
|
self.model = model
|
||||||
|
self.model.requires_grad_(False).eval()
|
||||||
|
self.scale=scale
|
||||||
|
self.downsampling_ratio = model.downsampling_ratio
|
||||||
|
self.io_channels = model.io_channels
|
||||||
|
self.sample_rate = model.sample_rate
|
||||||
|
|
||||||
|
self.model_half = model_half
|
||||||
|
self.iterate_batch = iterate_batch
|
||||||
|
|
||||||
|
self.encoded_channels = model.latent_dim
|
||||||
|
|
||||||
|
self.chunked = chunked
|
||||||
|
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||||
|
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
self.model.half()
|
||||||
|
|
||||||
|
def encode(self, x, **kwargs):
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
x = x.half()
|
||||||
|
self.model.to(torch.float16)
|
||||||
|
|
||||||
|
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
encoded = encoded.float()
|
||||||
|
|
||||||
|
return encoded / self.scale
|
||||||
|
|
||||||
|
def decode(self, z, **kwargs):
|
||||||
|
z = z * self.scale
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
z = z.half()
|
||||||
|
self.model.to(torch.float16)
|
||||||
|
|
||||||
|
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
decoded = decoded.float()
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def tokenize(self, x, **kwargs):
|
||||||
|
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
|
||||||
|
|
||||||
|
_, info = self.model.encode(x, return_info = True, **kwargs)
|
||||||
|
|
||||||
|
return info[self.model.bottleneck.tokens_id]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
|
||||||
|
|
||||||
|
return self.model.decode_tokens(tokens, **kwargs)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
|
self.model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
class WaveletPretransform(Pretransform):
|
||||||
|
def __init__(self, channels, levels, wavelet):
|
||||||
|
super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
|
||||||
|
|
||||||
|
from .wavelets import WaveletEncode1d, WaveletDecode1d
|
||||||
|
|
||||||
|
self.encoder = WaveletEncode1d(channels, levels, wavelet)
|
||||||
|
self.decoder = WaveletDecode1d(channels, levels, wavelet)
|
||||||
|
|
||||||
|
self.downsampling_ratio = 2 ** levels
|
||||||
|
self.io_channels = channels
|
||||||
|
self.encoded_channels = channels * self.downsampling_ratio
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
class PQMFPretransform(Pretransform):
|
||||||
|
def __init__(self, attenuation=100, num_bands=16):
|
||||||
|
# TODO: Fix PQMF to take in in-channels
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
|
||||||
|
from .pqmf import PQMF
|
||||||
|
self.pqmf = PQMF(attenuation, num_bands)
|
||||||
|
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
# x is (Batch x Channels x Time)
|
||||||
|
x = self.pqmf.forward(x)
|
||||||
|
# pqmf.forward returns (Batch x Channels x Bands x Time)
|
||||||
|
# but Pretransform needs Batch x Channels x Time
|
||||||
|
# so concatenate channels and bands into one axis
|
||||||
|
return rearrange(x, "b c n t -> b (c n) t")
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
|
||||||
|
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
|
||||||
|
# returns (Batch x Channels x Time)
|
||||||
|
return self.pqmf.inverse(x)
|
||||||
|
|
||||||
|
class PretrainedDACPretransform(Pretransform):
|
||||||
|
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||||
|
|
||||||
|
import dac
|
||||||
|
|
||||||
|
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
|
||||||
|
|
||||||
|
self.model = dac.DAC.load(model_path)
|
||||||
|
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
if model_type == "44khz":
|
||||||
|
self.downsampling_ratio = 512
|
||||||
|
else:
|
||||||
|
self.downsampling_ratio = 320
|
||||||
|
|
||||||
|
self.io_channels = 1
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.chunked = chunked
|
||||||
|
|
||||||
|
self.encoded_channels = self.model.latent_dim
|
||||||
|
|
||||||
|
self.num_quantizers = self.model.n_codebooks
|
||||||
|
|
||||||
|
self.codebook_size = self.model.codebook_size
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
|
||||||
|
latents = self.model.encoder(x)
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
output = latents
|
||||||
|
else:
|
||||||
|
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||||
|
output = z
|
||||||
|
|
||||||
|
if self.scale != 1.0:
|
||||||
|
output = output / self.scale
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
|
||||||
|
if self.scale != 1.0:
|
||||||
|
z = z * self.scale
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||||
|
|
||||||
|
return self.model.decode(z)
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
return self.model.encode(x)[1]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
latents = self.model.quantizer.from_codes(tokens)
|
||||||
|
return self.model.decode(latents)
|
||||||
|
|
||||||
|
class AudiocraftCompressionPretransform(Pretransform):
|
||||||
|
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from audiocraft.models import CompressionModel
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
|
||||||
|
|
||||||
|
self.model = CompressionModel.get_pretrained(model_type)
|
||||||
|
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
|
||||||
|
|
||||||
|
self.sample_rate = self.model.sample_rate
|
||||||
|
|
||||||
|
self.io_channels = self.model.channels
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
#self.encoded_channels = self.model.latent_dim
|
||||||
|
|
||||||
|
self.num_quantizers = self.model.num_codebooks
|
||||||
|
|
||||||
|
self.codebook_size = self.model.cardinality
|
||||||
|
|
||||||
|
self.model.to(torch.float16).eval().requires_grad_(False)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
|
||||||
|
assert False, "Audiocraft compression models do not support continuous encoding"
|
||||||
|
|
||||||
|
# latents = self.model.encoder(x)
|
||||||
|
|
||||||
|
# if self.quantize_on_decode:
|
||||||
|
# output = latents
|
||||||
|
# else:
|
||||||
|
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||||
|
# output = z
|
||||||
|
|
||||||
|
# if self.scale != 1.0:
|
||||||
|
# output = output / self.scale
|
||||||
|
|
||||||
|
# return output
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
|
||||||
|
assert False, "Audiocraft compression models do not support continuous decoding"
|
||||||
|
|
||||||
|
# if self.scale != 1.0:
|
||||||
|
# z = z * self.scale
|
||||||
|
|
||||||
|
# if self.quantize_on_decode:
|
||||||
|
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||||
|
|
||||||
|
# return self.model.decode(z)
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
return self.model.encode(x.to(torch.float16))[0]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
return self.model.decode(tokens)
|
||||||
190
stable_audio_tools/models/temptransformer.py
Normal file
190
stable_audio_tools/models/temptransformer.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
|
class Residual(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
return self.fn(x, **kwargs) + x
|
||||||
|
|
||||||
|
class SA_PreNorm(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.fn = fn
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
return self.fn(self.norm(x), **kwargs)
|
||||||
|
|
||||||
|
class SA_FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
class SA_Attention(nn.Module):
|
||||||
|
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
project_out = not (heads == 1 and dim_head == dim)
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
) if project_out else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, n, _, h = *x.shape, self.heads
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||||
|
|
||||||
|
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||||
|
|
||||||
|
attn = dots.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ReAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||||
|
|
||||||
|
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
|
||||||
|
|
||||||
|
self.reattn_norm = nn.Sequential(
|
||||||
|
Rearrange('b h i j -> b i j h'),
|
||||||
|
nn.LayerNorm(heads),
|
||||||
|
Rearrange('b i j h -> b h i j')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, n, _, h = *x.shape, self.heads
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||||
|
attn = dots.softmax(dim=-1)
|
||||||
|
|
||||||
|
# re-attention
|
||||||
|
|
||||||
|
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
|
||||||
|
attn = self.reattn_norm(attn)
|
||||||
|
|
||||||
|
# aggregate and out
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class LeFF(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
scale_dim = dim*scale
|
||||||
|
self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
|
||||||
|
Rearrange('b n c -> b c n'),
|
||||||
|
nn.BatchNorm1d(scale_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
Rearrange('b c (h w) -> b c h w', h=14, w=14)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
|
||||||
|
nn.BatchNorm2d(scale_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
Rearrange('b c h w -> b (h w) c', h=14, w=14)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
|
||||||
|
Rearrange('b n c -> b c n'),
|
||||||
|
nn.BatchNorm1d(dim),
|
||||||
|
nn.GELU(),
|
||||||
|
Rearrange('b c n -> b n c')
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.up_proj(x)
|
||||||
|
x = self.depth_conv(x)
|
||||||
|
x = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LCAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
project_out = not (heads == 1 and dim_head == dim)
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
) if project_out else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, n, _, h = *x.shape, self.heads
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||||
|
q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query
|
||||||
|
|
||||||
|
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||||
|
|
||||||
|
attn = dots.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SA_Transformer(nn.Module):
|
||||||
|
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||||
|
SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
x = attn(x) + x
|
||||||
|
x = ff(x) + x
|
||||||
|
return self.norm(x)
|
||||||
812
stable_audio_tools/models/transformer.py
Normal file
812
stable_audio_tools/models/transformer.py
Normal file
|
|
@ -0,0 +1,812 @@
|
||||||
|
from functools import reduce, partial
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, einsum
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
from typing import Callable, Literal
|
||||||
|
import warnings
|
||||||
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
||||||
|
except ImportError as e:
|
||||||
|
print(e)
|
||||||
|
print('flash_attn not installed, disabling Flash Attention')
|
||||||
|
flash_attn_kvpacked_func = None
|
||||||
|
flash_attn_func = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import natten
|
||||||
|
except ImportError:
|
||||||
|
natten = None
|
||||||
|
|
||||||
|
def checkpoint(function, *args, **kwargs):
|
||||||
|
kwargs.setdefault("use_reentrant", False)
|
||||||
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
||||||
|
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
||||||
|
|
||||||
|
def create_causal_mask(i, j, device):
|
||||||
|
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
||||||
|
|
||||||
|
def or_reduce(masks):
|
||||||
|
head, *body = masks
|
||||||
|
for rest in body:
|
||||||
|
head = head | rest
|
||||||
|
return head
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
|
||||||
|
class AbsolutePositionalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, max_seq_len):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** -0.5
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.emb = nn.Embedding(max_seq_len, dim)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||||
|
|
||||||
|
pos_emb = self.emb(pos)
|
||||||
|
pos_emb = pos_emb * self.scale
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
class ScaledSinusoidalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, theta = 10000):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
||||||
|
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
||||||
|
|
||||||
|
half_dim = dim // 2
|
||||||
|
freq_seq = torch.arange(half_dim).float() / half_dim
|
||||||
|
inv_freq = theta ** -freq_seq
|
||||||
|
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = pos - seq_start_pos[..., None]
|
||||||
|
|
||||||
|
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||||
|
return emb * self.scale
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
use_xpos = False,
|
||||||
|
scale_base = 512,
|
||||||
|
interpolation_factor = 1.,
|
||||||
|
base = 10000,
|
||||||
|
base_rescale_factor = 1.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||||
|
|
||||||
|
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
|
self.register_buffer('inv_freq', inv_freq)
|
||||||
|
|
||||||
|
assert interpolation_factor >= 1.
|
||||||
|
self.interpolation_factor = interpolation_factor
|
||||||
|
|
||||||
|
if not use_xpos:
|
||||||
|
self.register_buffer('scale', None)
|
||||||
|
return
|
||||||
|
|
||||||
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||||
|
|
||||||
|
self.scale_base = scale_base
|
||||||
|
self.register_buffer('scale', scale)
|
||||||
|
|
||||||
|
def forward_from_seq_len(self, seq_len):
|
||||||
|
device = self.inv_freq.device
|
||||||
|
|
||||||
|
t = torch.arange(seq_len, device = device)
|
||||||
|
return self.forward(t)
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
|
def forward(self, t):
|
||||||
|
device = self.inv_freq.device
|
||||||
|
|
||||||
|
t = t.to(torch.float32)
|
||||||
|
|
||||||
|
t = t / self.interpolation_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||||
|
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||||
|
|
||||||
|
if self.scale is None:
|
||||||
|
return freqs, 1.
|
||||||
|
|
||||||
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||||
|
scale = self.scale ** rearrange(power, 'n -> n 1')
|
||||||
|
scale = torch.cat((scale, scale), dim = -1)
|
||||||
|
|
||||||
|
return freqs, scale
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
||||||
|
x1, x2 = x.unbind(dim = -2)
|
||||||
|
return torch.cat((-x2, x1), dim = -1)
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
||||||
|
out_dtype = t.dtype
|
||||||
|
|
||||||
|
# cast to float32 if necessary for numerical stability
|
||||||
|
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
||||||
|
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
||||||
|
freqs, t = freqs.to(dtype), t.to(dtype)
|
||||||
|
freqs = freqs[-seq_len:, :]
|
||||||
|
|
||||||
|
if t.ndim == 4 and freqs.ndim == 3:
|
||||||
|
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
||||||
|
|
||||||
|
# partial rotary embeddings, Wang et al. GPT-J
|
||||||
|
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
||||||
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
||||||
|
|
||||||
|
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
||||||
|
|
||||||
|
return torch.cat((t, t_unrotated), dim = -1)
|
||||||
|
|
||||||
|
# norms
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, dim, bias=False, fix_scale=False):
|
||||||
|
"""
|
||||||
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if fix_scale:
|
||||||
|
self.register_buffer("gamma", torch.ones(dim))
|
||||||
|
else:
|
||||||
|
self.gamma = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.beta = nn.Parameter(torch.zeros(dim))
|
||||||
|
else:
|
||||||
|
self.register_buffer("beta", torch.zeros(dim))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
|
||||||
|
class GLU(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_in,
|
||||||
|
dim_out,
|
||||||
|
activation: Callable,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = activation
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_conv:
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.proj(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
else:
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
x, gate = x.chunk(2, dim = -1)
|
||||||
|
return x * self.act(gate)
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_out = None,
|
||||||
|
mult = 4,
|
||||||
|
no_bias = False,
|
||||||
|
glu = True,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
zero_init_output = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
|
||||||
|
# Default to SwiGLU
|
||||||
|
|
||||||
|
activation = nn.SiLU()
|
||||||
|
|
||||||
|
dim_out = dim if dim_out is None else dim_out
|
||||||
|
|
||||||
|
if glu:
|
||||||
|
linear_in = GLU(dim, inner_dim, activation)
|
||||||
|
else:
|
||||||
|
linear_in = nn.Sequential(
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
activation
|
||||||
|
)
|
||||||
|
|
||||||
|
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
|
||||||
|
|
||||||
|
# init last linear layer to 0
|
||||||
|
if zero_init_output:
|
||||||
|
nn.init.zeros_(linear_out.weight)
|
||||||
|
if not no_bias:
|
||||||
|
nn.init.zeros_(linear_out.bias)
|
||||||
|
|
||||||
|
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
linear_in,
|
||||||
|
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||||
|
linear_out,
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
dim_context = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_output=True,
|
||||||
|
qk_norm: Literal['l2', 'ln', 'none'] = 'none',
|
||||||
|
natten_kernel_size = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = dim_heads
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
dim_kv = dim_context if dim_context is not None else dim
|
||||||
|
|
||||||
|
self.num_heads = dim // dim_heads
|
||||||
|
self.kv_heads = dim_kv // dim_heads
|
||||||
|
|
||||||
|
if dim_context is not None:
|
||||||
|
self.to_q = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
||||||
|
else:
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
|
if zero_init_output:
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
|
||||||
|
if self.qk_norm == "ln":
|
||||||
|
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||||
|
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||||
|
|
||||||
|
# Using 1d neighborhood attention
|
||||||
|
self.natten_kernel_size = natten_kernel_size
|
||||||
|
if natten_kernel_size is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
|
||||||
|
|
||||||
|
self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
|
||||||
|
|
||||||
|
self.sdp_kwargs = dict(
|
||||||
|
enable_flash = True,
|
||||||
|
enable_math = True,
|
||||||
|
enable_mem_efficient = True
|
||||||
|
)
|
||||||
|
|
||||||
|
def flash_attn(
|
||||||
|
self,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
mask = None,
|
||||||
|
causal = None
|
||||||
|
):
|
||||||
|
batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
|
||||||
|
kv_heads = k.shape[1]
|
||||||
|
# Recommended for multi-query single-key-value attention by Tri Dao
|
||||||
|
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
||||||
|
|
||||||
|
if heads != kv_heads:
|
||||||
|
# Repeat interleave kv_heads to match q_heads
|
||||||
|
heads_per_kv_head = heads // kv_heads
|
||||||
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||||
|
|
||||||
|
if k.ndim == 3:
|
||||||
|
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
|
||||||
|
|
||||||
|
if v.ndim == 3:
|
||||||
|
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
|
|
||||||
|
if q_len == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert mask.ndim == 4
|
||||||
|
mask = mask.expand(batch, heads, q_len, k_len)
|
||||||
|
|
||||||
|
# handle kv cache - this should be bypassable in updated flash attention 2
|
||||||
|
|
||||||
|
if k_len > q_len and causal:
|
||||||
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
||||||
|
if mask is None:
|
||||||
|
mask = ~causal_mask
|
||||||
|
else:
|
||||||
|
mask = mask & ~causal_mask
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# manually handle causal mask, if another mask was given
|
||||||
|
|
||||||
|
row_is_entirely_masked = None
|
||||||
|
|
||||||
|
if mask is not None and causal:
|
||||||
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
||||||
|
mask = mask & ~causal_mask
|
||||||
|
|
||||||
|
# protect against an entire row being masked out
|
||||||
|
|
||||||
|
row_is_entirely_masked = ~mask.any(dim = -1)
|
||||||
|
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
||||||
|
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
|
||||||
|
out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v,
|
||||||
|
attn_mask = mask,
|
||||||
|
is_causal = causal
|
||||||
|
)
|
||||||
|
|
||||||
|
# for a row that is entirely masked out, should zero out the output of that row token
|
||||||
|
|
||||||
|
if row_is_entirely_masked is not None:
|
||||||
|
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
mask = None,
|
||||||
|
context_mask = None,
|
||||||
|
rotary_pos_emb = None,
|
||||||
|
causal = None
|
||||||
|
):
|
||||||
|
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||||
|
|
||||||
|
kv_input = context if has_context else x
|
||||||
|
|
||||||
|
if hasattr(self, 'to_q'):
|
||||||
|
# Use separate linear projections for q and k/v
|
||||||
|
q = self.to_q(x)
|
||||||
|
q = rearrange(q, 'b n (h d) -> b h n d', h = h) # [B, 24, 1025, 64]
|
||||||
|
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
||||||
|
else:
|
||||||
|
# Use fused linear projection
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||||
|
|
||||||
|
# Normalize q and k for cosine sim attention
|
||||||
|
if self.qk_norm == "l2":
|
||||||
|
q = F.normalize(q, dim=-1)
|
||||||
|
k = F.normalize(k, dim=-1)
|
||||||
|
elif self.qk_norm == "ln":
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
if rotary_pos_emb is not None and not has_context:
|
||||||
|
freqs, _ = rotary_pos_emb
|
||||||
|
|
||||||
|
q_dtype = q.dtype
|
||||||
|
k_dtype = k.dtype
|
||||||
|
|
||||||
|
q = q.to(torch.float32)
|
||||||
|
k = k.to(torch.float32)
|
||||||
|
freqs = freqs.to(torch.float32)
|
||||||
|
|
||||||
|
q = apply_rotary_pos_emb(q, freqs)
|
||||||
|
k = apply_rotary_pos_emb(k, freqs)
|
||||||
|
|
||||||
|
q = q.to(q_dtype)
|
||||||
|
k = k.to(k_dtype)
|
||||||
|
|
||||||
|
input_mask = context_mask
|
||||||
|
|
||||||
|
if input_mask is None and not has_context:
|
||||||
|
input_mask = mask
|
||||||
|
|
||||||
|
# determine masking
|
||||||
|
masks = []
|
||||||
|
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
||||||
|
|
||||||
|
if input_mask is not None:
|
||||||
|
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||||
|
masks.append(~input_mask)
|
||||||
|
|
||||||
|
# Other masks will be added here later
|
||||||
|
|
||||||
|
if len(masks) > 0:
|
||||||
|
final_attn_mask = ~or_reduce(masks)
|
||||||
|
|
||||||
|
n, device = q.shape[-2], q.device
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
|
|
||||||
|
if n == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
if self.natten_kernel_size is not None:
|
||||||
|
if natten is None:
|
||||||
|
raise ImportError('natten not installed, please install natten to use neighborhood attention')
|
||||||
|
|
||||||
|
dtype_in = q.dtype
|
||||||
|
q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
|
||||||
|
|
||||||
|
attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
|
||||||
|
|
||||||
|
if final_attn_mask is not None:
|
||||||
|
attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
|
||||||
|
|
||||||
|
attn = F.softmax(attn, dim=-1, dtype=torch.float32)
|
||||||
|
|
||||||
|
out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
|
||||||
|
|
||||||
|
# Prioritize Flash Attention 2
|
||||||
|
elif self.use_fa_flash:
|
||||||
|
assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
|
||||||
|
# Flash Attention 2 requires FP16 inputs
|
||||||
|
fa_dtype_in = q.dtype
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
|
||||||
|
|
||||||
|
out = flash_attn_func(q, k, v, causal = causal)
|
||||||
|
|
||||||
|
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
|
||||||
|
|
||||||
|
# Fall back to PyTorch implementation
|
||||||
|
elif self.use_pt_flash:
|
||||||
|
out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Fall back to custom implementation
|
||||||
|
|
||||||
|
if h != kv_h:
|
||||||
|
# Repeat interleave kv_heads to match q_heads
|
||||||
|
heads_per_kv_head = h // kv_h
|
||||||
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||||
|
|
||||||
|
scale = 1. / (q.shape[-1] ** 0.5)
|
||||||
|
|
||||||
|
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
|
||||||
|
|
||||||
|
dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
|
||||||
|
|
||||||
|
i, j, dtype = *dots.shape[-2:], dots.dtype
|
||||||
|
|
||||||
|
mask_value = -torch.finfo(dots.dtype).max
|
||||||
|
|
||||||
|
if final_attn_mask is not None:
|
||||||
|
dots = dots.masked_fill(~final_attn_mask, mask_value)
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
causal_mask = self.create_causal_mask(i, j, device = device)
|
||||||
|
dots = dots.masked_fill(causal_mask, mask_value)
|
||||||
|
|
||||||
|
attn = F.softmax(dots, dim=-1, dtype=torch.float32)
|
||||||
|
attn = attn.type(dtype)
|
||||||
|
|
||||||
|
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
out = rearrange(out, ' b h n d -> b n (h d)')
|
||||||
|
|
||||||
|
# Communicate between heads
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b n -> b n 1')
|
||||||
|
out = out.masked_fill(~mask, 0.)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerModule(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
norm_kwargs = {},
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
||||||
|
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
self.glu = GLU(dim, dim, nn.SiLU())
|
||||||
|
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
||||||
|
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
||||||
|
self.swish = nn.SiLU()
|
||||||
|
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.in_norm(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.glu(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.mid_norm(x)
|
||||||
|
x = self.swish(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv_2(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend = False,
|
||||||
|
dim_context = None,
|
||||||
|
global_cond_dim = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_branch_outputs = True,
|
||||||
|
conformer = False,
|
||||||
|
layer_ix = -1,
|
||||||
|
remove_norms = False,
|
||||||
|
attn_kwargs = {},
|
||||||
|
ff_kwargs = {},
|
||||||
|
norm_kwargs = {}
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = dim_heads
|
||||||
|
self.cross_attend = cross_attend
|
||||||
|
self.dim_context = dim_context
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
|
||||||
|
self.self_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if cross_attend:
|
||||||
|
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
self.cross_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
dim_context=dim_context,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
|
||||||
|
|
||||||
|
self.layer_ix = layer_ix
|
||||||
|
|
||||||
|
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
|
||||||
|
|
||||||
|
self.global_cond_dim = global_cond_dim
|
||||||
|
|
||||||
|
if global_cond_dim is not None:
|
||||||
|
self.to_scale_shift_gate = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(global_cond_dim, dim * 6, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
global_cond=None,
|
||||||
|
mask = None,
|
||||||
|
context_mask = None,
|
||||||
|
rotary_pos_emb = None,
|
||||||
|
adapter=None
|
||||||
|
):
|
||||||
|
|
||||||
|
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: # False
|
||||||
|
|
||||||
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
|
||||||
|
|
||||||
|
# self-attention with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
x = x * (1 + scale_self) + shift_self
|
||||||
|
|
||||||
|
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||||
|
x = x * torch.sigmoid(1 - gate_self)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
|
||||||
|
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer(x)
|
||||||
|
|
||||||
|
# feedforward with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.ff_norm(x)
|
||||||
|
x = x * (1 + scale_ff) + shift_ff
|
||||||
|
x = self.ff(x)
|
||||||
|
x = x * torch.sigmoid(1 - gate_ff)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
|
||||||
|
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer(x)
|
||||||
|
|
||||||
|
x = x + self.ff(self.ff_norm(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ContinuousTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
*,
|
||||||
|
dim_in = None,
|
||||||
|
dim_out = None,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend=False,
|
||||||
|
cond_token_dim=None,
|
||||||
|
global_cond_dim=None,
|
||||||
|
causal=False,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_outputs=True,
|
||||||
|
conformer=False,
|
||||||
|
use_sinusoidal_emb=False,
|
||||||
|
use_abs_pos_emb=False,
|
||||||
|
abs_pos_emb_max_length=10000,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
self.causal = causal
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
||||||
|
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
||||||
|
|
||||||
|
if rotary_pos_emb:
|
||||||
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||||
|
else:
|
||||||
|
self.rotary_pos_emb = None
|
||||||
|
|
||||||
|
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||||
|
if use_sinusoidal_emb:
|
||||||
|
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||||
|
|
||||||
|
self.use_abs_pos_emb = use_abs_pos_emb
|
||||||
|
if use_abs_pos_emb:
|
||||||
|
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
|
||||||
|
|
||||||
|
for i in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
TransformerBlock(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
cross_attend = cross_attend,
|
||||||
|
dim_context = cond_token_dim,
|
||||||
|
global_cond_dim = global_cond_dim,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||||
|
conformer=conformer,
|
||||||
|
layer_ix=i,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
mask = None,
|
||||||
|
prepend_embeds = None,
|
||||||
|
prepend_mask = None,
|
||||||
|
global_cond = None,
|
||||||
|
return_info = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
batch, seq, device = *x.shape[:2], x.device
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"hidden_states": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
x = self.project_in(x)
|
||||||
|
|
||||||
|
if prepend_embeds is not None:
|
||||||
|
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
||||||
|
|
||||||
|
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
||||||
|
|
||||||
|
x = torch.cat((prepend_embeds, x), dim = -2)
|
||||||
|
|
||||||
|
if prepend_mask is not None or mask is not None:
|
||||||
|
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
|
||||||
|
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
|
||||||
|
|
||||||
|
mask = torch.cat((prepend_mask, mask), dim = -1)
|
||||||
|
|
||||||
|
# Attention layers
|
||||||
|
if self.rotary_pos_emb is not None:
|
||||||
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||||
|
else:
|
||||||
|
rotary_pos_emb = None
|
||||||
|
|
||||||
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||||
|
x = x + self.pos_emb(x)
|
||||||
|
|
||||||
|
# Iterate over the transformer layers
|
||||||
|
for index, layer in enumerate(self.layers):
|
||||||
|
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
info["hidden_states"].append(x)
|
||||||
|
|
||||||
|
x = self.project_out(x)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
92
stable_audio_tools/models/utils.py
Normal file
92
stable_audio_tools/models/utils.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from torch.nn.utils import remove_weight_norm
|
||||||
|
import warnings
|
||||||
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def load_ckpt_state_dict(ckpt_path):
|
||||||
|
if ckpt_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_file(ckpt_path)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def remove_weight_norm_from_model(model):
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "weight"):
|
||||||
|
print(f"Removing weight norm from {module}")
|
||||||
|
remove_weight_norm(module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
||||||
|
# License can be found in LICENSES/LICENSE_META.txt
|
||||||
|
|
||||||
|
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
||||||
|
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): The input tensor containing probabilities.
|
||||||
|
num_samples (int): Number of samples to draw.
|
||||||
|
replacement (bool): Whether to draw with replacement or not.
|
||||||
|
Keywords args:
|
||||||
|
generator (torch.Generator): A pseudorandom number generator for sampling.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Last dimension contains num_samples indices
|
||||||
|
sampled from the multinomial probability distribution
|
||||||
|
located in the last dimension of tensor input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if num_samples == 1:
|
||||||
|
q = torch.empty_like(input).exponential_(1, generator=generator)
|
||||||
|
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
||||||
|
|
||||||
|
input_ = input.reshape(-1, input.shape[-1])
|
||||||
|
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
||||||
|
output = output_.reshape(*list(input.shape[:-1]), -1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
|
||||||
|
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||||
|
k (int): The k in “top-k”.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled tokens.
|
||||||
|
"""
|
||||||
|
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
||||||
|
min_value_top_k = top_k_value[..., [-1]]
|
||||||
|
probs *= (probs >= min_value_top_k).float()
|
||||||
|
probs.div_(probs.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = multinomial(probs, num_samples=1)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
||||||
|
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||||
|
p (int): The p in “top-p”.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled tokens.
|
||||||
|
"""
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort *= (~mask).float()
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
def next_power_of_two(n):
|
||||||
|
return 2 ** (n - 1).bit_length()
|
||||||
|
|
||||||
|
def next_multiple_of_64(n):
|
||||||
|
return ((n + 63) // 64) * 64
|
||||||
82
stable_audio_tools/models/wavelets.py
Normal file
82
stable_audio_tools/models/wavelets.py
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
"""The 1D discrete wavelet transform for PyTorch."""
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
import pywt
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
def get_filter_bank(wavelet):
|
||||||
|
filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
|
||||||
|
if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0):
|
||||||
|
filt = filt[:, 1:]
|
||||||
|
return filt
|
||||||
|
|
||||||
|
class WaveletEncode1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
channels,
|
||||||
|
levels,
|
||||||
|
wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
|
||||||
|
super().__init__()
|
||||||
|
self.wavelet = wavelet
|
||||||
|
self.channels = channels
|
||||||
|
self.levels = levels
|
||||||
|
filt = get_filter_bank(wavelet)
|
||||||
|
assert filt.shape[-1] % 2 == 1
|
||||||
|
kernel = filt[:2, None]
|
||||||
|
kernel = torch.flip(kernel, dims=(-1,))
|
||||||
|
index_i = torch.repeat_interleave(torch.arange(2), channels)
|
||||||
|
index_j = torch.tile(torch.arange(channels), (2,))
|
||||||
|
kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
|
||||||
|
kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
|
||||||
|
self.register_buffer("kernel", kernel_final)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i in range(self.levels):
|
||||||
|
low, rest = x[:, : self.channels], x[:, self.channels :]
|
||||||
|
pad = self.kernel.shape[-1] // 2
|
||||||
|
low = F.pad(low, (pad, pad), "reflect")
|
||||||
|
low = F.conv1d(low, self.kernel, stride=2)
|
||||||
|
rest = rearrange(
|
||||||
|
rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
|
||||||
|
)
|
||||||
|
x = torch.cat([low, rest], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WaveletDecode1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
channels,
|
||||||
|
levels,
|
||||||
|
wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
|
||||||
|
super().__init__()
|
||||||
|
self.wavelet = wavelet
|
||||||
|
self.channels = channels
|
||||||
|
self.levels = levels
|
||||||
|
filt = get_filter_bank(wavelet)
|
||||||
|
assert filt.shape[-1] % 2 == 1
|
||||||
|
kernel = filt[2:, None]
|
||||||
|
index_i = torch.repeat_interleave(torch.arange(2), channels)
|
||||||
|
index_j = torch.tile(torch.arange(channels), (2,))
|
||||||
|
kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
|
||||||
|
kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
|
||||||
|
self.register_buffer("kernel", kernel_final)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i in range(self.levels):
|
||||||
|
low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
|
||||||
|
pad = self.kernel.shape[-1] // 2 + 2
|
||||||
|
low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
|
||||||
|
low = F.pad(low, (pad, pad), "reflect")
|
||||||
|
low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
|
||||||
|
low = F.conv_transpose1d(
|
||||||
|
low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
|
||||||
|
)
|
||||||
|
low = low[..., pad - 1 : -pad]
|
||||||
|
rest = rearrange(
|
||||||
|
rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
|
||||||
|
)
|
||||||
|
x = torch.cat([low, rest], dim=1)
|
||||||
|
return x
|
||||||
1
stable_audio_tools/training/__init__.py
Normal file
1
stable_audio_tools/training/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .factory import create_training_wrapper_from_config, create_demo_callback_from_config
|
||||||
476
stable_audio_tools/training/autoencoders.py
Normal file
476
stable_audio_tools/training/autoencoders.py
Normal file
|
|
@ -0,0 +1,476 @@
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import wandb
|
||||||
|
from einops import rearrange
|
||||||
|
from safetensors.torch import save_file, save_model
|
||||||
|
from ema_pytorch import EMA
|
||||||
|
from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from ..models.autoencoders import AudioAutoencoder
|
||||||
|
from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
|
||||||
|
from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
|
||||||
|
from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
|
||||||
|
from .utils import create_optimizer_from_config, create_scheduler_from_config
|
||||||
|
|
||||||
|
|
||||||
|
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||||
|
from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
|
||||||
|
|
||||||
|
class AutoencoderTrainingWrapper(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
autoencoder: AudioAutoencoder,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
warmup_steps: int = 0,
|
||||||
|
encoder_freeze_on_warmup: bool = False,
|
||||||
|
sample_rate=48000,
|
||||||
|
loss_config: dict = None,
|
||||||
|
optimizer_configs: dict = None,
|
||||||
|
use_ema: bool = True,
|
||||||
|
ema_copy = None,
|
||||||
|
force_input_mono = False,
|
||||||
|
latent_mask_ratio = 0.0,
|
||||||
|
teacher_model: AudioAutoencoder = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.automatic_optimization = False
|
||||||
|
|
||||||
|
self.autoencoder = autoencoder
|
||||||
|
|
||||||
|
self.warmed_up = False
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
|
||||||
|
self.lr = lr
|
||||||
|
|
||||||
|
self.force_input_mono = force_input_mono
|
||||||
|
|
||||||
|
self.teacher_model = teacher_model
|
||||||
|
|
||||||
|
if optimizer_configs is None:
|
||||||
|
optimizer_configs ={
|
||||||
|
"autoencoder": {
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"config": {
|
||||||
|
"lr": lr,
|
||||||
|
"betas": (.8, .99)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"discriminator": {
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"config": {
|
||||||
|
"lr": lr,
|
||||||
|
"betas": (.8, .99)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
self.optimizer_configs = optimizer_configs
|
||||||
|
|
||||||
|
if loss_config is None:
|
||||||
|
scales = [2048, 1024, 512, 256, 128, 64, 32]
|
||||||
|
hop_sizes = []
|
||||||
|
win_lengths = []
|
||||||
|
overlap = 0.75
|
||||||
|
for s in scales:
|
||||||
|
hop_sizes.append(int(s * (1 - overlap)))
|
||||||
|
win_lengths.append(s)
|
||||||
|
|
||||||
|
loss_config = {
|
||||||
|
"discriminator": {
|
||||||
|
"type": "encodec",
|
||||||
|
"config": {
|
||||||
|
"n_ffts": scales,
|
||||||
|
"hop_lengths": hop_sizes,
|
||||||
|
"win_lengths": win_lengths,
|
||||||
|
"filters": 32
|
||||||
|
},
|
||||||
|
"weights": {
|
||||||
|
"adversarial": 0.1,
|
||||||
|
"feature_matching": 5.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"spectral": {
|
||||||
|
"type": "mrstft",
|
||||||
|
"config": {
|
||||||
|
"fft_sizes": scales,
|
||||||
|
"hop_sizes": hop_sizes,
|
||||||
|
"win_lengths": win_lengths,
|
||||||
|
"perceptual_weighting": True
|
||||||
|
},
|
||||||
|
"weights": {
|
||||||
|
"mrstft": 1.0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"time": {
|
||||||
|
"type": "l1",
|
||||||
|
"config": {},
|
||||||
|
"weights": {
|
||||||
|
"l1": 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.loss_config = loss_config
|
||||||
|
|
||||||
|
# Spectral reconstruction loss
|
||||||
|
|
||||||
|
stft_loss_args = loss_config['spectral']['config']
|
||||||
|
|
||||||
|
if self.autoencoder.out_channels == 2:
|
||||||
|
self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
|
||||||
|
self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
|
||||||
|
else:
|
||||||
|
self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
|
||||||
|
|
||||||
|
# Discriminator
|
||||||
|
|
||||||
|
if loss_config['discriminator']['type'] == 'oobleck':
|
||||||
|
self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
|
||||||
|
elif loss_config['discriminator']['type'] == 'encodec':
|
||||||
|
self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
|
||||||
|
elif loss_config['discriminator']['type'] == 'dac':
|
||||||
|
self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
|
||||||
|
|
||||||
|
self.gen_loss_modules = []
|
||||||
|
|
||||||
|
# Adversarial and feature matching losses
|
||||||
|
self.gen_loss_modules += [
|
||||||
|
ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
|
||||||
|
ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.teacher_model is not None:
|
||||||
|
# Distillation losses
|
||||||
|
|
||||||
|
stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
|
||||||
|
self.gen_loss_modules += [
|
||||||
|
AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
|
||||||
|
AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
|
||||||
|
AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
|
||||||
|
AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
# Reconstruction loss
|
||||||
|
self.gen_loss_modules += [
|
||||||
|
AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.autoencoder.out_channels == 2:
|
||||||
|
|
||||||
|
# Add left and right channel reconstruction losses in addition to the sum and difference
|
||||||
|
self.gen_loss_modules += [
|
||||||
|
AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
|
||||||
|
AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.gen_loss_modules += [
|
||||||
|
AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.loss_config['time']['weights']['l1'] > 0.0:
|
||||||
|
self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))
|
||||||
|
|
||||||
|
if self.autoencoder.bottleneck is not None:
|
||||||
|
self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)
|
||||||
|
|
||||||
|
self.losses_gen = MultiLoss(self.gen_loss_modules)
|
||||||
|
|
||||||
|
self.disc_loss_modules = [
|
||||||
|
ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.losses_disc = MultiLoss(self.disc_loss_modules)
|
||||||
|
|
||||||
|
# Set up EMA for model weights
|
||||||
|
self.autoencoder_ema = None
|
||||||
|
|
||||||
|
self.use_ema = use_ema
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.autoencoder_ema = EMA(
|
||||||
|
self.autoencoder,
|
||||||
|
ema_model=ema_copy,
|
||||||
|
beta=0.9999,
|
||||||
|
power=3/4,
|
||||||
|
update_every=1,
|
||||||
|
update_after_step=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.latent_mask_ratio = latent_mask_ratio
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
|
||||||
|
opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
|
||||||
|
opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())
|
||||||
|
|
||||||
|
if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
|
||||||
|
sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
|
||||||
|
sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
|
||||||
|
return [opt_gen, opt_disc], [sched_gen, sched_disc]
|
||||||
|
|
||||||
|
return [opt_gen, opt_disc]
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
reals, _ = batch
|
||||||
|
|
||||||
|
# Remove extra dimension added by WebDataset
|
||||||
|
if reals.ndim == 4 and reals.shape[0] == 1:
|
||||||
|
reals = reals[0]
|
||||||
|
|
||||||
|
if self.global_step >= self.warmup_steps:
|
||||||
|
self.warmed_up = True
|
||||||
|
|
||||||
|
loss_info = {}
|
||||||
|
|
||||||
|
loss_info["reals"] = reals
|
||||||
|
|
||||||
|
encoder_input = reals
|
||||||
|
|
||||||
|
if self.force_input_mono and encoder_input.shape[1] > 1:
|
||||||
|
encoder_input = encoder_input.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
loss_info["encoder_input"] = encoder_input
|
||||||
|
|
||||||
|
data_std = encoder_input.std()
|
||||||
|
|
||||||
|
if self.warmed_up and self.encoder_freeze_on_warmup:
|
||||||
|
with torch.no_grad():
|
||||||
|
latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
|
||||||
|
else:
|
||||||
|
latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
|
||||||
|
|
||||||
|
loss_info["latents"] = latents
|
||||||
|
|
||||||
|
loss_info.update(encoder_info)
|
||||||
|
|
||||||
|
# Encode with teacher model for distillation
|
||||||
|
if self.teacher_model is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
|
||||||
|
loss_info['teacher_latents'] = teacher_latents
|
||||||
|
|
||||||
|
if self.latent_mask_ratio > 0.0:
|
||||||
|
mask = torch.rand_like(latents) < self.latent_mask_ratio
|
||||||
|
latents = torch.where(mask, torch.zeros_like(latents), latents)
|
||||||
|
|
||||||
|
decoded = self.autoencoder.decode(latents)
|
||||||
|
|
||||||
|
loss_info["decoded"] = decoded
|
||||||
|
|
||||||
|
if self.autoencoder.out_channels == 2:
|
||||||
|
loss_info["decoded_left"] = decoded[:, 0:1, :]
|
||||||
|
loss_info["decoded_right"] = decoded[:, 1:2, :]
|
||||||
|
loss_info["reals_left"] = reals[:, 0:1, :]
|
||||||
|
loss_info["reals_right"] = reals[:, 1:2, :]
|
||||||
|
|
||||||
|
# Distillation
|
||||||
|
if self.teacher_model is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_decoded = self.teacher_model.decode(teacher_latents)
|
||||||
|
own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
|
||||||
|
teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model
|
||||||
|
|
||||||
|
loss_info['teacher_decoded'] = teacher_decoded
|
||||||
|
loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
|
||||||
|
loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded
|
||||||
|
|
||||||
|
|
||||||
|
if self.warmed_up:
|
||||||
|
loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
|
||||||
|
else:
|
||||||
|
loss_dis = torch.tensor(0.).to(reals)
|
||||||
|
loss_adv = torch.tensor(0.).to(reals)
|
||||||
|
feature_matching_distance = torch.tensor(0.).to(reals)
|
||||||
|
|
||||||
|
loss_info["loss_dis"] = loss_dis
|
||||||
|
loss_info["loss_adv"] = loss_adv
|
||||||
|
loss_info["feature_matching_distance"] = feature_matching_distance
|
||||||
|
|
||||||
|
opt_gen, opt_disc = self.optimizers()
|
||||||
|
|
||||||
|
lr_schedulers = self.lr_schedulers()
|
||||||
|
|
||||||
|
sched_gen = None
|
||||||
|
sched_disc = None
|
||||||
|
|
||||||
|
if lr_schedulers is not None:
|
||||||
|
sched_gen, sched_disc = lr_schedulers
|
||||||
|
|
||||||
|
# Train the discriminator
|
||||||
|
if self.global_step % 2 and self.warmed_up:
|
||||||
|
loss, losses = self.losses_disc(loss_info)
|
||||||
|
|
||||||
|
log_dict = {
|
||||||
|
'train/disc_lr': opt_disc.param_groups[0]['lr']
|
||||||
|
}
|
||||||
|
|
||||||
|
opt_disc.zero_grad()
|
||||||
|
self.manual_backward(loss)
|
||||||
|
opt_disc.step()
|
||||||
|
|
||||||
|
if sched_disc is not None:
|
||||||
|
# sched step every step
|
||||||
|
sched_disc.step()
|
||||||
|
|
||||||
|
# Train the generator
|
||||||
|
else:
|
||||||
|
|
||||||
|
loss, losses = self.losses_gen(loss_info)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.autoencoder_ema.update()
|
||||||
|
|
||||||
|
opt_gen.zero_grad()
|
||||||
|
self.manual_backward(loss)
|
||||||
|
opt_gen.step()
|
||||||
|
|
||||||
|
if sched_gen is not None:
|
||||||
|
# scheduler step every step
|
||||||
|
sched_gen.step()
|
||||||
|
|
||||||
|
log_dict = {
|
||||||
|
'train/loss': loss.detach(),
|
||||||
|
'train/latent_std': latents.std().detach(),
|
||||||
|
'train/data_std': data_std.detach(),
|
||||||
|
'train/gen_lr': opt_gen.param_groups[0]['lr']
|
||||||
|
}
|
||||||
|
|
||||||
|
for loss_name, loss_value in losses.items():
|
||||||
|
log_dict[f'train/{loss_name}'] = loss_value.detach()
|
||||||
|
|
||||||
|
self.log_dict(log_dict, prog_bar=True, on_step=True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def export_model(self, path, use_safetensors=False):
|
||||||
|
if self.autoencoder_ema is not None:
|
||||||
|
model = self.autoencoder_ema.ema_model
|
||||||
|
else:
|
||||||
|
model = self.autoencoder
|
||||||
|
|
||||||
|
if use_safetensors:
|
||||||
|
save_model(model, path)
|
||||||
|
else:
|
||||||
|
torch.save({"state_dict": model.state_dict()}, path)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderDemoCallback(pl.Callback):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
demo_dl,
|
||||||
|
demo_every=2000,
|
||||||
|
sample_size=65536,
|
||||||
|
sample_rate=48000
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.demo_every = demo_every
|
||||||
|
self.demo_samples = sample_size
|
||||||
|
self.demo_dl = iter(demo_dl)
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.last_demo_step = -1
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
@torch.no_grad()
|
||||||
|
def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
|
||||||
|
if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.last_demo_step = trainer.global_step
|
||||||
|
|
||||||
|
module.eval()
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_reals, _ = next(self.demo_dl)
|
||||||
|
|
||||||
|
# Remove extra dimension added by WebDataset
|
||||||
|
if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
|
||||||
|
demo_reals = demo_reals[0]
|
||||||
|
|
||||||
|
encoder_input = demo_reals
|
||||||
|
|
||||||
|
encoder_input = encoder_input.to(module.device)
|
||||||
|
|
||||||
|
if module.force_input_mono:
|
||||||
|
encoder_input = encoder_input.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
demo_reals = demo_reals.to(module.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if module.use_ema:
|
||||||
|
|
||||||
|
latents = module.autoencoder_ema.ema_model.encode(encoder_input)
|
||||||
|
|
||||||
|
fakes = module.autoencoder_ema.ema_model.decode(latents)
|
||||||
|
else:
|
||||||
|
latents = module.autoencoder.encode(encoder_input)
|
||||||
|
|
||||||
|
fakes = module.autoencoder.decode(latents)
|
||||||
|
|
||||||
|
#Interleave reals and fakes
|
||||||
|
reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
|
||||||
|
|
||||||
|
# Put the demos together
|
||||||
|
reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
|
||||||
|
|
||||||
|
log_dict = {}
|
||||||
|
|
||||||
|
filename = f'recon_{trainer.global_step:08}.wav'
|
||||||
|
reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
||||||
|
torchaudio.save(filename, reals_fakes, self.sample_rate)
|
||||||
|
|
||||||
|
log_dict[f'recon'] = wandb.Audio(filename,
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
caption=f'Reconstructed')
|
||||||
|
|
||||||
|
log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
|
||||||
|
log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
|
||||||
|
|
||||||
|
log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
|
||||||
|
|
||||||
|
trainer.logger.experiment.log(log_dict)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'{type(e).__name__}: {e}')
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
module.train()
|
||||||
|
|
||||||
|
def create_loss_modules_from_bottleneck(bottleneck, loss_config):
|
||||||
|
losses = []
|
||||||
|
|
||||||
|
if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
|
||||||
|
try:
|
||||||
|
kl_weight = loss_config['bottleneck']['weights']['kl']
|
||||||
|
except:
|
||||||
|
kl_weight = 1e-6
|
||||||
|
|
||||||
|
kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss')
|
||||||
|
losses.append(kl_loss)
|
||||||
|
|
||||||
|
if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
|
||||||
|
quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss')
|
||||||
|
losses.append(quantizer_loss)
|
||||||
|
|
||||||
|
if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck):
|
||||||
|
codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss')
|
||||||
|
commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss')
|
||||||
|
losses.append(codebook_loss)
|
||||||
|
losses.append(commitment_loss)
|
||||||
|
|
||||||
|
if isinstance(bottleneck, WassersteinBottleneck):
|
||||||
|
try:
|
||||||
|
mmd_weight = loss_config['bottleneck']['weights']['mmd']
|
||||||
|
except:
|
||||||
|
mmd_weight = 100
|
||||||
|
|
||||||
|
mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss')
|
||||||
|
losses.append(mmd_loss)
|
||||||
|
|
||||||
|
return losses
|
||||||
1656
stable_audio_tools/training/diffusion.py
Normal file
1656
stable_audio_tools/training/diffusion.py
Normal file
File diff suppressed because it is too large
Load diff
240
stable_audio_tools/training/factory.py
Normal file
240
stable_audio_tools/training/factory.py
Normal file
|
|
@ -0,0 +1,240 @@
|
||||||
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
|
from ..models.factory import create_model_from_config
|
||||||
|
|
||||||
|
def create_training_wrapper_from_config(model_config, model):
|
||||||
|
model_type = model_config.get('model_type', None)
|
||||||
|
assert model_type is not None, 'model_type must be specified in model config'
|
||||||
|
|
||||||
|
training_config = model_config.get('training', None)
|
||||||
|
assert training_config is not None, 'training config must be specified in model config'
|
||||||
|
|
||||||
|
if model_type == 'autoencoder':
|
||||||
|
from .autoencoders import AutoencoderTrainingWrapper
|
||||||
|
|
||||||
|
ema_copy = None
|
||||||
|
|
||||||
|
if training_config.get("use_ema", False):
|
||||||
|
ema_copy = create_model_from_config(model_config)
|
||||||
|
ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
|
||||||
|
# Copy each weight to the ema copy
|
||||||
|
for name, param in model.state_dict().items():
|
||||||
|
if isinstance(param, Parameter):
|
||||||
|
# backwards compatibility for serialized parameters
|
||||||
|
param = param.data
|
||||||
|
ema_copy.state_dict()[name].copy_(param)
|
||||||
|
|
||||||
|
use_ema = training_config.get("use_ema", False)
|
||||||
|
|
||||||
|
latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)
|
||||||
|
|
||||||
|
teacher_model = training_config.get("teacher_model", None)
|
||||||
|
if teacher_model is not None:
|
||||||
|
teacher_model = create_model_from_config(teacher_model)
|
||||||
|
teacher_model = teacher_model.eval().requires_grad_(False)
|
||||||
|
|
||||||
|
teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
|
||||||
|
if teacher_model_ckpt is not None:
|
||||||
|
teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
|
||||||
|
else:
|
||||||
|
raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")
|
||||||
|
|
||||||
|
return AutoencoderTrainingWrapper(
|
||||||
|
model,
|
||||||
|
lr=training_config["learning_rate"],
|
||||||
|
warmup_steps=training_config.get("warmup_steps", 0),
|
||||||
|
encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
loss_config=training_config.get("loss_configs", None),
|
||||||
|
optimizer_configs=training_config.get("optimizer_configs", None),
|
||||||
|
use_ema=use_ema,
|
||||||
|
ema_copy=ema_copy if use_ema else None,
|
||||||
|
force_input_mono=training_config.get("force_input_mono", False),
|
||||||
|
latent_mask_ratio=latent_mask_ratio,
|
||||||
|
teacher_model=teacher_model
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_uncond':
|
||||||
|
from .diffusion import DiffusionUncondTrainingWrapper
|
||||||
|
return DiffusionUncondTrainingWrapper(
|
||||||
|
model,
|
||||||
|
lr=training_config["learning_rate"],
|
||||||
|
pre_encoded=training_config.get("pre_encoded", False),
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_cond':
|
||||||
|
from .diffusion import DiffusionCondTrainingWrapper
|
||||||
|
return DiffusionCondTrainingWrapper(
|
||||||
|
model,
|
||||||
|
lr=training_config.get("learning_rate", None),
|
||||||
|
mask_padding=training_config.get("mask_padding", False),
|
||||||
|
mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
|
||||||
|
use_ema = training_config.get("use_ema", True),
|
||||||
|
log_loss_info=training_config.get("log_loss_info", False),
|
||||||
|
optimizer_configs=training_config.get("optimizer_configs", None),
|
||||||
|
pre_encoded=training_config.get("pre_encoded", False),
|
||||||
|
cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
|
||||||
|
timestep_sampler = training_config.get("timestep_sampler", "uniform")
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_prior':
|
||||||
|
from .diffusion import DiffusionPriorTrainingWrapper
|
||||||
|
from ..models.diffusion_prior import PriorType
|
||||||
|
|
||||||
|
ema_copy = create_model_from_config(model_config)
|
||||||
|
|
||||||
|
# Copy each weight to the ema copy
|
||||||
|
for name, param in model.state_dict().items():
|
||||||
|
if isinstance(param, Parameter):
|
||||||
|
# backwards compatibility for serialized parameters
|
||||||
|
param = param.data
|
||||||
|
ema_copy.state_dict()[name].copy_(param)
|
||||||
|
|
||||||
|
prior_type = training_config.get("prior_type", "mono_stereo")
|
||||||
|
|
||||||
|
if prior_type == "mono_stereo":
|
||||||
|
prior_type_enum = PriorType.MonoToStereo
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prior type: {prior_type}")
|
||||||
|
|
||||||
|
return DiffusionPriorTrainingWrapper(
|
||||||
|
model,
|
||||||
|
lr=training_config["learning_rate"],
|
||||||
|
ema_copy=ema_copy,
|
||||||
|
prior_type=prior_type_enum,
|
||||||
|
log_loss_info=training_config.get("log_loss_info", False),
|
||||||
|
use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_cond_inpaint':
|
||||||
|
from .diffusion import DiffusionCondInpaintTrainingWrapper
|
||||||
|
return DiffusionCondInpaintTrainingWrapper(
|
||||||
|
model,
|
||||||
|
lr=training_config.get("learning_rate", None),
|
||||||
|
max_mask_segments = training_config.get("max_mask_segments", 10),
|
||||||
|
log_loss_info=training_config.get("log_loss_info", False),
|
||||||
|
optimizer_configs=training_config.get("optimizer_configs", None),
|
||||||
|
use_ema=training_config.get("use_ema", True),
|
||||||
|
pre_encoded=training_config.get("pre_encoded", False),
|
||||||
|
cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
|
||||||
|
timestep_sampler = training_config.get("timestep_sampler", "uniform")
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_autoencoder':
|
||||||
|
from .diffusion import DiffusionAutoencoderTrainingWrapper
|
||||||
|
|
||||||
|
ema_copy = create_model_from_config(model_config)
|
||||||
|
|
||||||
|
# Copy each weight to the ema copy
|
||||||
|
for name, param in model.state_dict().items():
|
||||||
|
if isinstance(param, Parameter):
|
||||||
|
# backwards compatibility for serialized parameters
|
||||||
|
param = param.data
|
||||||
|
ema_copy.state_dict()[name].copy_(param)
|
||||||
|
|
||||||
|
return DiffusionAutoencoderTrainingWrapper(
|
||||||
|
model,
|
||||||
|
ema_copy=ema_copy,
|
||||||
|
lr=training_config["learning_rate"],
|
||||||
|
use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
|
||||||
|
)
|
||||||
|
elif model_type == 'lm':
|
||||||
|
from .lm import AudioLanguageModelTrainingWrapper
|
||||||
|
|
||||||
|
ema_copy = create_model_from_config(model_config)
|
||||||
|
|
||||||
|
for name, param in model.state_dict().items():
|
||||||
|
if isinstance(param, Parameter):
|
||||||
|
# backwards compatibility for serialized parameters
|
||||||
|
param = param.data
|
||||||
|
ema_copy.state_dict()[name].copy_(param)
|
||||||
|
|
||||||
|
return AudioLanguageModelTrainingWrapper(
|
||||||
|
model,
|
||||||
|
ema_copy=ema_copy,
|
||||||
|
lr=training_config.get("learning_rate", None),
|
||||||
|
use_ema=training_config.get("use_ema", False),
|
||||||
|
optimizer_configs=training_config.get("optimizer_configs", None),
|
||||||
|
pre_encoded=training_config.get("pre_encoded", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
def create_demo_callback_from_config(model_config, **kwargs):
|
||||||
|
model_type = model_config.get('model_type', None)
|
||||||
|
assert model_type is not None, 'model_type must be specified in model config'
|
||||||
|
|
||||||
|
training_config = model_config.get('training', None)
|
||||||
|
assert training_config is not None, 'training config must be specified in model config'
|
||||||
|
|
||||||
|
demo_config = training_config.get("demo", {})
|
||||||
|
|
||||||
|
if model_type == 'autoencoder':
|
||||||
|
from .autoencoders import AutoencoderDemoCallback
|
||||||
|
return AutoencoderDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif model_type == 'diffusion_uncond':
|
||||||
|
from .diffusion import DiffusionUncondDemoCallback
|
||||||
|
return DiffusionUncondDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
demo_steps=demo_config.get("demo_steps", 250),
|
||||||
|
sample_rate=model_config["sample_rate"]
|
||||||
|
)
|
||||||
|
elif model_type == "diffusion_autoencoder":
|
||||||
|
from .diffusion import DiffusionAutoencoderDemoCallback
|
||||||
|
return DiffusionAutoencoderDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
demo_steps=demo_config.get("demo_steps", 250),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif model_type == "diffusion_prior":
|
||||||
|
from .diffusion import DiffusionPriorDemoCallback
|
||||||
|
return DiffusionPriorDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
demo_steps=demo_config.get("demo_steps", 250),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif model_type == "diffusion_cond":
|
||||||
|
from .diffusion import DiffusionCondDemoCallback
|
||||||
|
|
||||||
|
return DiffusionCondDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
demo_steps=demo_config.get("demo_steps", 250),
|
||||||
|
num_demos=demo_config["num_demos"],
|
||||||
|
demo_cfg_scales=demo_config["demo_cfg_scales"],
|
||||||
|
demo_conditioning=demo_config.get("demo_cond", {}),
|
||||||
|
demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
|
||||||
|
display_audio_cond=demo_config.get("display_audio_cond", False),
|
||||||
|
)
|
||||||
|
elif model_type == "diffusion_cond_inpaint":
|
||||||
|
from .diffusion import DiffusionCondInpaintDemoCallback
|
||||||
|
|
||||||
|
return DiffusionCondInpaintDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
demo_steps=demo_config.get("demo_steps", 250),
|
||||||
|
demo_cfg_scales=demo_config["demo_cfg_scales"],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "lm":
|
||||||
|
from .lm import AudioLanguageModelDemoCallback
|
||||||
|
|
||||||
|
return AudioLanguageModelDemoCallback(
|
||||||
|
demo_every=demo_config.get("demo_every", 2000),
|
||||||
|
sample_size=model_config["sample_size"],
|
||||||
|
sample_rate=model_config["sample_rate"],
|
||||||
|
demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
|
||||||
|
demo_conditioning=demo_config.get("demo_cond", None),
|
||||||
|
num_demos=demo_config.get("num_demos", 8),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
267
stable_audio_tools/training/lm.py
Normal file
267
stable_audio_tools/training/lm.py
Normal file
|
|
@ -0,0 +1,267 @@
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import sys, gc
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import typing as tp
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
|
||||||
|
from ema_pytorch import EMA
|
||||||
|
from einops import rearrange
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from torch import optim
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||||
|
|
||||||
|
from ..models.lm import AudioLanguageModelWrapper
|
||||||
|
from .utils import create_optimizer_from_config, create_scheduler_from_config
|
||||||
|
|
||||||
|
class AudioLanguageModelTrainingWrapper(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: AudioLanguageModelWrapper,
|
||||||
|
lr = 1e-4,
|
||||||
|
use_ema=False,
|
||||||
|
ema_copy=None,
|
||||||
|
optimizer_configs: dict = None,
|
||||||
|
pre_encoded=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
self.model.pretransform.requires_grad_(False)
|
||||||
|
|
||||||
|
self.model_ema = None
|
||||||
|
if use_ema:
|
||||||
|
self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)
|
||||||
|
|
||||||
|
assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
|
||||||
|
|
||||||
|
if optimizer_configs is None:
|
||||||
|
optimizer_configs = {
|
||||||
|
"lm": {
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"config": {
|
||||||
|
"lr": lr,
|
||||||
|
"betas": (0.9, 0.95),
|
||||||
|
"weight_decay": 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if lr is not None:
|
||||||
|
print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
|
||||||
|
|
||||||
|
self.optimizer_configs = optimizer_configs
|
||||||
|
|
||||||
|
self.pre_encoded = pre_encoded
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lm_opt_config = self.optimizer_configs['lm']
|
||||||
|
opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())
|
||||||
|
|
||||||
|
if "scheduler" in lm_opt_config:
|
||||||
|
sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
|
||||||
|
sched_lm_config = {
|
||||||
|
"scheduler": sched_lm,
|
||||||
|
"interval": "step"
|
||||||
|
}
|
||||||
|
return [opt_lm], [sched_lm_config]
|
||||||
|
|
||||||
|
return [opt_lm]
|
||||||
|
|
||||||
|
# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
|
||||||
|
# License can be found in LICENSES/LICENSE_META.txt
|
||||||
|
|
||||||
|
def _compute_cross_entropy(
|
||||||
|
self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
|
||||||
|
) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
|
||||||
|
"""Compute cross entropy between multi-codebook targets and model's logits.
|
||||||
|
The cross entropy is computed per codebook to provide codebook-level cross entropy.
|
||||||
|
Valid timesteps for each of the codebook are pulled from the mask, where invalid
|
||||||
|
timesteps are set to 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
|
||||||
|
targets (torch.Tensor): Target codes, of shape [B, K, T].
|
||||||
|
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
|
||||||
|
Returns:
|
||||||
|
ce (torch.Tensor): Cross entropy averaged over the codebooks
|
||||||
|
ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
|
||||||
|
"""
|
||||||
|
B, K, T = targets.shape
|
||||||
|
assert logits.shape[:-1] == targets.shape
|
||||||
|
assert mask.shape == targets.shape
|
||||||
|
ce = torch.zeros([], device=targets.device)
|
||||||
|
ce_per_codebook: tp.List[torch.Tensor] = []
|
||||||
|
for k in range(K):
|
||||||
|
logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
|
||||||
|
targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
|
||||||
|
mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
|
||||||
|
ce_targets = targets_k[mask_k]
|
||||||
|
ce_logits = logits_k[mask_k]
|
||||||
|
q_ce = F.cross_entropy(ce_logits, ce_targets)
|
||||||
|
ce += q_ce
|
||||||
|
ce_per_codebook.append(q_ce.detach())
|
||||||
|
# average cross entropy across codebooks
|
||||||
|
ce = ce / K
|
||||||
|
return ce, ce_per_codebook
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
reals, metadata = batch
|
||||||
|
|
||||||
|
if reals.ndim == 4 and reals.shape[0] == 1:
|
||||||
|
reals = reals[0]
|
||||||
|
|
||||||
|
if not self.pre_encoded:
|
||||||
|
codes = self.model.pretransform.tokenize(reals)
|
||||||
|
else:
|
||||||
|
codes = reals
|
||||||
|
|
||||||
|
padding_masks = []
|
||||||
|
for md in metadata:
|
||||||
|
if md["padding_mask"].ndim == 1:
|
||||||
|
padding_masks.append(md["padding_mask"])
|
||||||
|
else:
|
||||||
|
padding_masks.append(md["padding_mask"][0])
|
||||||
|
|
||||||
|
padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length)
|
||||||
|
|
||||||
|
# Interpolate padding masks to the same length as the codes
|
||||||
|
padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()
|
||||||
|
|
||||||
|
condition_tensors = None
|
||||||
|
|
||||||
|
# If the model is conditioned, get the conditioning tensors
|
||||||
|
if self.model.conditioner is not None:
|
||||||
|
condition_tensors = self.model.conditioner(metadata, self.device)
|
||||||
|
|
||||||
|
lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)
|
||||||
|
|
||||||
|
logits = lm_output.logits # [b, k, t, c]
|
||||||
|
logits_mask = lm_output.mask # [b, k, t]
|
||||||
|
|
||||||
|
logits_mask = logits_mask & padding_masks
|
||||||
|
|
||||||
|
cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)
|
||||||
|
|
||||||
|
loss = cross_entropy
|
||||||
|
|
||||||
|
log_dict = {
|
||||||
|
'train/loss': loss.detach(),
|
||||||
|
'train/cross_entropy': cross_entropy.detach(),
|
||||||
|
'train/perplexity': torch.exp(cross_entropy).detach(),
|
||||||
|
'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, ce_q in enumerate(cross_entropy_per_codebook):
|
||||||
|
log_dict[f'cross_entropy_q{k + 1}'] = ce_q
|
||||||
|
log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)
|
||||||
|
|
||||||
|
self.log_dict(log_dict, prog_bar=True, on_step=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def on_before_zero_grad(self, *args, **kwargs):
|
||||||
|
if self.model_ema is not None:
|
||||||
|
self.model_ema.update()
|
||||||
|
|
||||||
|
def export_model(self, path, use_safetensors=False):
|
||||||
|
|
||||||
|
model = self.model_ema.ema_model if self.model_ema is not None else self.model
|
||||||
|
|
||||||
|
if use_safetensors:
|
||||||
|
save_file(model.state_dict(), path)
|
||||||
|
else:
|
||||||
|
torch.save({"state_dict": model.state_dict()}, path)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLanguageModelDemoCallback(pl.Callback):
|
||||||
|
def __init__(self,
|
||||||
|
demo_every=2000,
|
||||||
|
num_demos=8,
|
||||||
|
sample_size=65536,
|
||||||
|
sample_rate=48000,
|
||||||
|
demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
|
||||||
|
demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.demo_every = demo_every
|
||||||
|
self.num_demos = num_demos
|
||||||
|
self.demo_samples = sample_size
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.last_demo_step = -1
|
||||||
|
self.demo_conditioning = demo_conditioning
|
||||||
|
self.demo_cfg_scales = demo_cfg_scales
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
@torch.no_grad()
|
||||||
|
def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):
|
||||||
|
|
||||||
|
if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
|
||||||
|
return
|
||||||
|
|
||||||
|
module.eval()
|
||||||
|
|
||||||
|
print(f"Generating demo")
|
||||||
|
self.last_demo_step = trainer.global_step
|
||||||
|
|
||||||
|
demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio
|
||||||
|
|
||||||
|
#demo_reals = batch[0][:self.num_demos]
|
||||||
|
|
||||||
|
# if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
|
||||||
|
# demo_reals = demo_reals[0]
|
||||||
|
|
||||||
|
#demo_reals_tokens = module.model.pretransform.tokenize(demo_reals)
|
||||||
|
|
||||||
|
##Limit to first 50 tokens
|
||||||
|
#demo_reals_tokens = demo_reals_tokens[:, :, :50]
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Getting conditioning")
|
||||||
|
|
||||||
|
for cfg_scale in self.demo_cfg_scales:
|
||||||
|
|
||||||
|
model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model
|
||||||
|
|
||||||
|
print(f"Generating demo for cfg scale {cfg_scale}")
|
||||||
|
fakes = model.generate_audio(
|
||||||
|
batch_size=self.num_demos,
|
||||||
|
max_gen_len=demo_length_tokens,
|
||||||
|
conditioning=self.demo_conditioning,
|
||||||
|
#init_data = demo_reals_tokens,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
temp=1.0,
|
||||||
|
top_p=0.95
|
||||||
|
)
|
||||||
|
|
||||||
|
# Put the demos together
|
||||||
|
fakes = rearrange(fakes, 'b d n -> d (b n)')
|
||||||
|
|
||||||
|
log_dict = {}
|
||||||
|
|
||||||
|
filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
|
||||||
|
fakes = fakes / fakes.abs().max()
|
||||||
|
fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
|
||||||
|
torchaudio.save(filename, fakes, self.sample_rate)
|
||||||
|
|
||||||
|
log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
caption=f'Reconstructed')
|
||||||
|
|
||||||
|
log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
|
||||||
|
|
||||||
|
trainer.logger.experiment.log(log_dict)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
module.train()
|
||||||
1
stable_audio_tools/training/losses/__init__.py
Normal file
1
stable_audio_tools/training/losses/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .losses import *
|
||||||
607
stable_audio_tools/training/losses/auraloss.py
Normal file
607
stable_audio_tools/training/losses/auraloss.py
Normal file
|
|
@ -0,0 +1,607 @@
|
||||||
|
# Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0
|
||||||
|
# You can find the license at LICENSES/LICENSE_AURALOSS.txt
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Any
|
||||||
|
import scipy.signal
|
||||||
|
|
||||||
|
def apply_reduction(losses, reduction="none"):
|
||||||
|
"""Apply reduction to collection of losses."""
|
||||||
|
if reduction == "mean":
|
||||||
|
losses = losses.mean()
|
||||||
|
elif reduction == "sum":
|
||||||
|
losses = losses.sum()
|
||||||
|
return losses
|
||||||
|
|
||||||
|
def get_window(win_type: str, win_length: int):
|
||||||
|
"""Return a window function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
win_type (str): Window type. Can either be one of the window function provided in PyTorch
|
||||||
|
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
|
||||||
|
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
|
||||||
|
win_length (int): Window length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
win: The window as a 1D torch tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
win = getattr(torch, win_type)(win_length)
|
||||||
|
except:
|
||||||
|
win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
|
||||||
|
|
||||||
|
return win
|
||||||
|
|
||||||
|
class SumAndDifference(torch.nn.Module):
|
||||||
|
"""Sum and difference signal extraction module."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize sum and difference extraction module."""
|
||||||
|
super(SumAndDifference, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Predicted signal (B, #channels, #samples).
|
||||||
|
Returns:
|
||||||
|
Tensor: Sum signal.
|
||||||
|
Tensor: Difference signal.
|
||||||
|
"""
|
||||||
|
if not (x.size(1) == 2): # inputs must be stereo
|
||||||
|
raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
|
||||||
|
|
||||||
|
sum_sig = self.sum(x).unsqueeze(1)
|
||||||
|
diff_sig = self.diff(x).unsqueeze(1)
|
||||||
|
|
||||||
|
return sum_sig, diff_sig
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sum(x):
|
||||||
|
return x[:, 0, :] + x[:, 1, :]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def diff(x):
|
||||||
|
return x[:, 0, :] - x[:, 1, :]
|
||||||
|
|
||||||
|
|
||||||
|
class FIRFilter(torch.nn.Module):
|
||||||
|
"""FIR pre-emphasis filtering module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
|
||||||
|
coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
|
||||||
|
ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
|
||||||
|
plot (bool): Plot the magnitude respond of the filter. Default: False
|
||||||
|
|
||||||
|
Based upon the perceptual loss pre-empahsis filters proposed by
|
||||||
|
[Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
|
||||||
|
|
||||||
|
A-weighting filter - "aw"
|
||||||
|
First-order highpass - "hp"
|
||||||
|
Folded differentiator - "fd"
|
||||||
|
|
||||||
|
Note that the default coefficeint value of 0.85 is optimized for
|
||||||
|
a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
|
||||||
|
"""Initilize FIR pre-emphasis filtering module."""
|
||||||
|
super(FIRFilter, self).__init__()
|
||||||
|
self.filter_type = filter_type
|
||||||
|
self.coef = coef
|
||||||
|
self.fs = fs
|
||||||
|
self.ntaps = ntaps
|
||||||
|
self.plot = plot
|
||||||
|
|
||||||
|
import scipy.signal
|
||||||
|
|
||||||
|
if ntaps % 2 == 0:
|
||||||
|
raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
|
||||||
|
|
||||||
|
if filter_type == "hp":
|
||||||
|
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
|
||||||
|
self.fir.weight.requires_grad = False
|
||||||
|
self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
|
||||||
|
elif filter_type == "fd":
|
||||||
|
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
|
||||||
|
self.fir.weight.requires_grad = False
|
||||||
|
self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
|
||||||
|
elif filter_type == "aw":
|
||||||
|
# Definition of analog A-weighting filter according to IEC/CD 1672.
|
||||||
|
f1 = 20.598997
|
||||||
|
f2 = 107.65265
|
||||||
|
f3 = 737.86223
|
||||||
|
f4 = 12194.217
|
||||||
|
A1000 = 1.9997
|
||||||
|
|
||||||
|
NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
|
||||||
|
DENs = np.polymul(
|
||||||
|
[1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
|
||||||
|
[1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
|
||||||
|
)
|
||||||
|
DENs = np.polymul(
|
||||||
|
np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert analog filter to digital filter
|
||||||
|
b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
|
||||||
|
|
||||||
|
# compute the digital filter frequency response
|
||||||
|
w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
|
||||||
|
|
||||||
|
# then we fit to 101 tap FIR filter with least squares
|
||||||
|
taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
|
||||||
|
|
||||||
|
# now implement this digital FIR filter as a Conv1d layer
|
||||||
|
self.fir = torch.nn.Conv1d(
|
||||||
|
1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
|
||||||
|
)
|
||||||
|
self.fir.weight.requires_grad = False
|
||||||
|
self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
from .plotting import compare_filters
|
||||||
|
compare_filters(b, a, taps, fs=fs)
|
||||||
|
|
||||||
|
def forward(self, input, target):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
input (Tensor): Predicted signal (B, #channels, #samples).
|
||||||
|
target (Tensor): Groundtruth signal (B, #channels, #samples).
|
||||||
|
Returns:
|
||||||
|
Tensor: Filtered signal.
|
||||||
|
"""
|
||||||
|
input = torch.nn.functional.conv1d(
|
||||||
|
input, self.fir.weight.data, padding=self.ntaps // 2
|
||||||
|
)
|
||||||
|
target = torch.nn.functional.conv1d(
|
||||||
|
target, self.fir.weight.data, padding=self.ntaps // 2
|
||||||
|
)
|
||||||
|
return input, target
|
||||||
|
|
||||||
|
class SpectralConvergenceLoss(torch.nn.Module):
|
||||||
|
"""Spectral convergence loss module.
|
||||||
|
|
||||||
|
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(SpectralConvergenceLoss, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean()
|
||||||
|
|
||||||
|
class STFTMagnitudeLoss(torch.nn.Module):
|
||||||
|
"""STFT magnitude loss module.
|
||||||
|
|
||||||
|
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
|
||||||
|
and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
|
||||||
|
|
||||||
|
Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
|
||||||
|
compression strength (larger value results in more compression), and `log_eps` can be used
|
||||||
|
to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
|
||||||
|
output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log (bool, optional): Log-scale the STFT magnitudes,
|
||||||
|
or use linear scale. Default: True
|
||||||
|
log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
|
||||||
|
Default: 0.0
|
||||||
|
log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
|
||||||
|
Default: 1.0
|
||||||
|
distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
|
||||||
|
reduction (str, optional): Reduction of the loss elements. Default: "mean"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
|
||||||
|
super(STFTMagnitudeLoss, self).__init__()
|
||||||
|
|
||||||
|
self.log = log
|
||||||
|
self.log_eps = log_eps
|
||||||
|
self.log_fac = log_fac
|
||||||
|
|
||||||
|
if distance == "L1":
|
||||||
|
self.distance = torch.nn.L1Loss(reduction=reduction)
|
||||||
|
elif distance == "L2":
|
||||||
|
self.distance = torch.nn.MSELoss(reduction=reduction)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid distance: '{distance}'.")
|
||||||
|
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
if self.log:
|
||||||
|
x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
|
||||||
|
y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
|
||||||
|
return self.distance(x_mag, y_mag)
|
||||||
|
|
||||||
|
|
||||||
|
class STFTLoss(torch.nn.Module):
|
||||||
|
"""STFT loss module.
|
||||||
|
|
||||||
|
See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fft_size (int, optional): FFT size in samples. Default: 1024
|
||||||
|
hop_size (int, optional): Hop size of the FFT in samples. Default: 256
|
||||||
|
win_length (int, optional): Length of the FFT analysis window. Default: 1024
|
||||||
|
window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch
|
||||||
|
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
|
||||||
|
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
|
||||||
|
Default: 'hann_window'
|
||||||
|
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
|
||||||
|
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
|
||||||
|
w_lin_mag_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
|
||||||
|
w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
|
||||||
|
sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
|
||||||
|
scale (str, optional): Optional frequency scaling method, options include:
|
||||||
|
['mel', 'chroma']
|
||||||
|
Default: None
|
||||||
|
n_bins (int, optional): Number of scaling frequency bins. Default: None.
|
||||||
|
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
|
||||||
|
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
|
||||||
|
eps (float, optional): Small epsilon value for stablity. Default: 1e-8
|
||||||
|
output (str, optional): Format of the loss returned.
|
||||||
|
'loss' : Return only the raw, aggregate loss term.
|
||||||
|
'full' : Return the raw loss, plus intermediate loss terms.
|
||||||
|
Default: 'loss'
|
||||||
|
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||||
|
'none': no reduction will be applied,
|
||||||
|
'mean': the sum of the output will be divided by the number of elements in the output,
|
||||||
|
'sum': the output will be summed.
|
||||||
|
Default: 'mean'
|
||||||
|
mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
|
||||||
|
device (str, optional): Place the filterbanks on specified device. Default: None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
loss:
|
||||||
|
Aggreate loss term. Only returned if output='loss'. By default.
|
||||||
|
loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
|
||||||
|
Aggregate and intermediate loss terms. Only returned if output='full'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fft_size: int = 1024,
|
||||||
|
hop_size: int = 256,
|
||||||
|
win_length: int = 1024,
|
||||||
|
window: str = "hann_window",
|
||||||
|
w_sc: float = 1.0,
|
||||||
|
w_log_mag: float = 1.0,
|
||||||
|
w_lin_mag: float = 0.0,
|
||||||
|
w_phs: float = 0.0,
|
||||||
|
sample_rate: float = None,
|
||||||
|
scale: str = None,
|
||||||
|
n_bins: int = None,
|
||||||
|
perceptual_weighting: bool = False,
|
||||||
|
scale_invariance: bool = False,
|
||||||
|
eps: float = 1e-8,
|
||||||
|
output: str = "loss",
|
||||||
|
reduction: str = "mean",
|
||||||
|
mag_distance: str = "L1",
|
||||||
|
device: Any = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.fft_size = fft_size
|
||||||
|
self.hop_size = hop_size
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = get_window(window, win_length)
|
||||||
|
self.w_sc = w_sc
|
||||||
|
self.w_log_mag = w_log_mag
|
||||||
|
self.w_lin_mag = w_lin_mag
|
||||||
|
self.w_phs = w_phs
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.scale = scale
|
||||||
|
self.n_bins = n_bins
|
||||||
|
self.perceptual_weighting = perceptual_weighting
|
||||||
|
self.scale_invariance = scale_invariance
|
||||||
|
self.eps = eps
|
||||||
|
self.output = output
|
||||||
|
self.reduction = reduction
|
||||||
|
self.mag_distance = mag_distance
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
self.phs_used = bool(self.w_phs)
|
||||||
|
|
||||||
|
self.spectralconv = SpectralConvergenceLoss()
|
||||||
|
self.logstft = STFTMagnitudeLoss(
|
||||||
|
log=True,
|
||||||
|
reduction=reduction,
|
||||||
|
distance=mag_distance,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
self.linstft = STFTMagnitudeLoss(
|
||||||
|
log=False,
|
||||||
|
reduction=reduction,
|
||||||
|
distance=mag_distance,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup mel filterbank
|
||||||
|
if scale is not None:
|
||||||
|
try:
|
||||||
|
import librosa.filters
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print("Try `pip install auraloss[all]`.")
|
||||||
|
|
||||||
|
if self.scale == "mel":
|
||||||
|
assert sample_rate != None # Must set sample rate to use mel scale
|
||||||
|
assert n_bins <= fft_size # Must be more FFT bins than Mel bins
|
||||||
|
fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
|
||||||
|
fb = torch.tensor(fb).unsqueeze(0)
|
||||||
|
|
||||||
|
elif self.scale == "chroma":
|
||||||
|
assert sample_rate != None # Must set sample rate to use chroma scale
|
||||||
|
assert n_bins <= fft_size # Must be more FFT bins than chroma bins
|
||||||
|
fb = librosa.filters.chroma(
|
||||||
|
sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("fb", fb)
|
||||||
|
|
||||||
|
if scale is not None and device is not None:
|
||||||
|
self.fb = self.fb.to(self.device) # move filterbank to device
|
||||||
|
|
||||||
|
if self.perceptual_weighting:
|
||||||
|
if sample_rate is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`sample_rate` must be supplied when `perceptual_weighting = True`."
|
||||||
|
)
|
||||||
|
self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
|
||||||
|
|
||||||
|
def stft(self, x):
|
||||||
|
"""Perform STFT.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input signal tensor (B, T).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: x_mag, x_phs
|
||||||
|
Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
|
||||||
|
"""
|
||||||
|
x_stft = torch.stft(
|
||||||
|
x,
|
||||||
|
self.fft_size,
|
||||||
|
self.hop_size,
|
||||||
|
self.win_length,
|
||||||
|
self.window,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
x_mag = torch.sqrt(
|
||||||
|
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
|
||||||
|
)
|
||||||
|
|
||||||
|
# torch.angle is expensive, so it is only evaluated if the values are used in the loss
|
||||||
|
if self.phs_used:
|
||||||
|
x_phs = torch.angle(x_stft)
|
||||||
|
else:
|
||||||
|
x_phs = None
|
||||||
|
|
||||||
|
return x_mag, x_phs
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor, target: torch.Tensor):
|
||||||
|
bs, chs, seq_len = input.size()
|
||||||
|
|
||||||
|
if self.perceptual_weighting: # apply optional A-weighting via FIR filter
|
||||||
|
# since FIRFilter only support mono audio we will move channels to batch dim
|
||||||
|
input = input.view(bs * chs, 1, -1)
|
||||||
|
target = target.view(bs * chs, 1, -1)
|
||||||
|
|
||||||
|
# now apply the filter to both
|
||||||
|
self.prefilter.to(input.device)
|
||||||
|
input, target = self.prefilter(input, target)
|
||||||
|
|
||||||
|
# now move the channels back
|
||||||
|
input = input.view(bs, chs, -1)
|
||||||
|
target = target.view(bs, chs, -1)
|
||||||
|
|
||||||
|
# compute the magnitude and phase spectra of input and target
|
||||||
|
self.window = self.window.to(input.device)
|
||||||
|
|
||||||
|
x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
|
||||||
|
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
|
||||||
|
|
||||||
|
# apply relevant transforms
|
||||||
|
if self.scale is not None:
|
||||||
|
self.fb = self.fb.to(input.device)
|
||||||
|
x_mag = torch.matmul(self.fb, x_mag)
|
||||||
|
y_mag = torch.matmul(self.fb, y_mag)
|
||||||
|
|
||||||
|
# normalize scales
|
||||||
|
if self.scale_invariance:
|
||||||
|
alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
|
||||||
|
y_mag = y_mag * alpha.unsqueeze(-1)
|
||||||
|
|
||||||
|
# compute loss terms
|
||||||
|
sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
|
||||||
|
log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
|
||||||
|
lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
|
||||||
|
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0
|
||||||
|
|
||||||
|
# combine loss terms
|
||||||
|
loss = (
|
||||||
|
(self.w_sc * sc_mag_loss)
|
||||||
|
+ (self.w_log_mag * log_mag_loss)
|
||||||
|
+ (self.w_lin_mag * lin_mag_loss)
|
||||||
|
+ (self.w_phs * phs_loss)
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = apply_reduction(loss, reduction=self.reduction)
|
||||||
|
|
||||||
|
if self.output == "loss":
|
||||||
|
return loss
|
||||||
|
elif self.output == "full":
|
||||||
|
return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
|
||||||
|
|
||||||
|
class MultiResolutionSTFTLoss(torch.nn.Module):
|
||||||
|
"""Multi resolution STFT loss module.
|
||||||
|
|
||||||
|
See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fft_sizes (list): List of FFT sizes.
|
||||||
|
hop_sizes (list): List of hop sizes.
|
||||||
|
win_lengths (list): List of window lengths.
|
||||||
|
window (str, optional): Window to apply before FFT, options include:
|
||||||
|
'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
|
||||||
|
Default: 'hann_window'
|
||||||
|
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
|
||||||
|
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
|
||||||
|
w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
|
||||||
|
w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
|
||||||
|
sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
|
||||||
|
scale (str, optional): Optional frequency scaling method, options include:
|
||||||
|
['mel', 'chroma']
|
||||||
|
Default: None
|
||||||
|
n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
|
||||||
|
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fft_sizes: List[int] = [1024, 2048, 512],
|
||||||
|
hop_sizes: List[int] = [120, 240, 50],
|
||||||
|
win_lengths: List[int] = [600, 1200, 240],
|
||||||
|
window: str = "hann_window",
|
||||||
|
w_sc: float = 1.0,
|
||||||
|
w_log_mag: float = 1.0,
|
||||||
|
w_lin_mag: float = 0.0,
|
||||||
|
w_phs: float = 0.0,
|
||||||
|
sample_rate: float = None,
|
||||||
|
scale: str = None,
|
||||||
|
n_bins: int = None,
|
||||||
|
perceptual_weighting: bool = False,
|
||||||
|
scale_invariance: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
|
||||||
|
self.fft_sizes = fft_sizes
|
||||||
|
self.hop_sizes = hop_sizes
|
||||||
|
self.win_lengths = win_lengths
|
||||||
|
|
||||||
|
self.stft_losses = torch.nn.ModuleList()
|
||||||
|
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||||
|
self.stft_losses += [
|
||||||
|
STFTLoss(
|
||||||
|
fs,
|
||||||
|
ss,
|
||||||
|
wl,
|
||||||
|
window,
|
||||||
|
w_sc,
|
||||||
|
w_log_mag,
|
||||||
|
w_lin_mag,
|
||||||
|
w_phs,
|
||||||
|
sample_rate,
|
||||||
|
scale,
|
||||||
|
n_bins,
|
||||||
|
perceptual_weighting,
|
||||||
|
scale_invariance,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
mrstft_loss = 0.0
|
||||||
|
sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []
|
||||||
|
|
||||||
|
for f in self.stft_losses:
|
||||||
|
if f.output == "full": # extract just first term
|
||||||
|
tmp_loss = f(x, y)
|
||||||
|
mrstft_loss += tmp_loss[0]
|
||||||
|
sc_mag_loss.append(tmp_loss[1])
|
||||||
|
log_mag_loss.append(tmp_loss[2])
|
||||||
|
lin_mag_loss.append(tmp_loss[3])
|
||||||
|
phs_loss.append(tmp_loss[4])
|
||||||
|
else:
|
||||||
|
mrstft_loss += f(x, y)
|
||||||
|
|
||||||
|
mrstft_loss /= len(self.stft_losses)
|
||||||
|
|
||||||
|
if f.output == "loss":
|
||||||
|
return mrstft_loss
|
||||||
|
else:
|
||||||
|
return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
|
||||||
|
|
||||||
|
|
||||||
|
class SumAndDifferenceSTFTLoss(torch.nn.Module):
|
||||||
|
"""Sum and difference sttereo STFT loss module.
|
||||||
|
|
||||||
|
See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fft_sizes (List[int]): List of FFT sizes.
|
||||||
|
hop_sizes (List[int]): List of hop sizes.
|
||||||
|
win_lengths (List[int]): List of window lengths.
|
||||||
|
window (str, optional): Window function type.
|
||||||
|
w_sum (float, optional): Weight of the sum loss component. Default: 1.0
|
||||||
|
w_diff (float, optional): Weight of the difference loss component. Default: 1.0
|
||||||
|
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
|
||||||
|
mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False
|
||||||
|
n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
|
||||||
|
sample_rate (float, optional): Audio sample rate. Default: None
|
||||||
|
output (str, optional): Format of the loss returned.
|
||||||
|
'loss' : Return only the raw, aggregate loss term.
|
||||||
|
'full' : Return the raw loss, plus intermediate loss terms.
|
||||||
|
Default: 'loss'
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fft_sizes: List[int],
|
||||||
|
hop_sizes: List[int],
|
||||||
|
win_lengths: List[int],
|
||||||
|
window: str = "hann_window",
|
||||||
|
w_sum: float = 1.0,
|
||||||
|
w_diff: float = 1.0,
|
||||||
|
output: str = "loss",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.sd = SumAndDifference()
|
||||||
|
self.w_sum = w_sum
|
||||||
|
self.w_diff = w_diff
|
||||||
|
self.output = output
|
||||||
|
self.mrstft = MultiResolutionSTFTLoss(
|
||||||
|
fft_sizes,
|
||||||
|
hop_sizes,
|
||||||
|
win_lengths,
|
||||||
|
window,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor, target: torch.Tensor):
|
||||||
|
"""This loss function assumes batched input of stereo audio in the time domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
|
||||||
|
target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'.
|
||||||
|
loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
|
||||||
|
Aggregate and intermediate loss terms. Only returned if output='full'.
|
||||||
|
"""
|
||||||
|
assert input.shape == target.shape # must have same shape
|
||||||
|
bs, chs, seq_len = input.size()
|
||||||
|
|
||||||
|
# compute sum and difference signals for both
|
||||||
|
input_sum, input_diff = self.sd(input)
|
||||||
|
target_sum, target_diff = self.sd(target)
|
||||||
|
|
||||||
|
# compute error in STFT domain
|
||||||
|
sum_loss = self.mrstft(input_sum, target_sum)
|
||||||
|
diff_loss = self.mrstft(input_diff, target_diff)
|
||||||
|
loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2
|
||||||
|
|
||||||
|
if self.output == "loss":
|
||||||
|
return loss
|
||||||
|
elif self.output == "full":
|
||||||
|
return loss, sum_loss, diff_loss
|
||||||
101
stable_audio_tools/training/losses/losses.py
Normal file
101
stable_audio_tools/training/losses/losses.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
class LossModule(nn.Module):
|
||||||
|
def __init__(self, name: str, weight: float = 1.0):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
def forward(self, info, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class ValueLoss(LossModule):
|
||||||
|
def __init__(self, key: str, name, weight: float = 1.0):
|
||||||
|
super().__init__(name=name, weight=weight)
|
||||||
|
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def forward(self, info):
|
||||||
|
return self.weight * info[self.key]
|
||||||
|
|
||||||
|
class L1Loss(LossModule):
|
||||||
|
def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'):
|
||||||
|
super().__init__(name=name, weight=weight)
|
||||||
|
|
||||||
|
self.key_a = key_a
|
||||||
|
self.key_b = key_b
|
||||||
|
|
||||||
|
self.mask_key = mask_key
|
||||||
|
|
||||||
|
def forward(self, info):
|
||||||
|
mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none')
|
||||||
|
|
||||||
|
if self.mask_key is not None and self.mask_key in info:
|
||||||
|
mse_loss = mse_loss[info[self.mask_key]]
|
||||||
|
|
||||||
|
mse_loss = mse_loss.mean()
|
||||||
|
|
||||||
|
return self.weight * mse_loss
|
||||||
|
|
||||||
|
class MSELoss(LossModule):
|
||||||
|
def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'):
|
||||||
|
super().__init__(name=name, weight=weight)
|
||||||
|
|
||||||
|
self.key_a = key_a
|
||||||
|
self.key_b = key_b
|
||||||
|
|
||||||
|
self.mask_key = mask_key
|
||||||
|
|
||||||
|
def forward(self, info):
|
||||||
|
mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none')
|
||||||
|
|
||||||
|
if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None:
|
||||||
|
mask = info[self.mask_key]
|
||||||
|
|
||||||
|
if mask.ndim == 2 and mse_loss.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
if mask.shape[1] != mse_loss.shape[1]:
|
||||||
|
mask = mask.repeat(1, mse_loss.shape[1], 1)
|
||||||
|
|
||||||
|
mse_loss = mse_loss[mask]
|
||||||
|
|
||||||
|
mse_loss = mse_loss.mean()
|
||||||
|
|
||||||
|
return self.weight * mse_loss
|
||||||
|
|
||||||
|
class AuralossLoss(LossModule):
|
||||||
|
def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1):
|
||||||
|
super().__init__(name, weight)
|
||||||
|
|
||||||
|
self.auraloss_module = auraloss_module
|
||||||
|
|
||||||
|
self.input_key = input_key
|
||||||
|
self.target_key = target_key
|
||||||
|
|
||||||
|
def forward(self, info):
|
||||||
|
loss = self.auraloss_module(info[self.input_key], info[self.target_key])
|
||||||
|
|
||||||
|
return self.weight * loss
|
||||||
|
|
||||||
|
class MultiLoss(nn.Module):
|
||||||
|
def __init__(self, losses: tp.List[LossModule]):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.losses = nn.ModuleList(losses)
|
||||||
|
|
||||||
|
def forward(self, info):
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
losses = {}
|
||||||
|
|
||||||
|
for loss_module in self.losses:
|
||||||
|
module_loss = loss_module(info)
|
||||||
|
total_loss += module_loss
|
||||||
|
losses[loss_module.name] = module_loss
|
||||||
|
|
||||||
|
return total_loss, losses
|
||||||
111
stable_audio_tools/training/utils.py
Normal file
111
stable_audio_tools/training/utils.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
"""Get rank of current process."""
|
||||||
|
|
||||||
|
print(os.environ.keys())
|
||||||
|
|
||||||
|
if "SLURM_PROCID" in os.environ:
|
||||||
|
return int(os.environ["SLURM_PROCID"])
|
||||||
|
|
||||||
|
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return torch.distributed.get_rank()
|
||||||
|
|
||||||
|
class InverseLR(torch.optim.lr_scheduler._LRScheduler):
|
||||||
|
"""Implements an inverse decay learning rate schedule with an optional exponential
|
||||||
|
warmup. When last_epoch=-1, sets initial lr as lr.
|
||||||
|
inv_gamma is the number of steps/epochs required for the learning rate to decay to
|
||||||
|
(1 / 2)**power of its original value.
|
||||||
|
Args:
|
||||||
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
|
||||||
|
power (float): Exponential factor of learning rate decay. Default: 1.
|
||||||
|
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
|
||||||
|
Default: 0.
|
||||||
|
final_lr (float): The final learning rate. Default: 0.
|
||||||
|
last_epoch (int): The index of last epoch. Default: -1.
|
||||||
|
verbose (bool): If ``True``, prints a message to stdout for
|
||||||
|
each update. Default: ``False``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
|
||||||
|
last_epoch=-1, verbose=False):
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
if not 0. <= warmup < 1:
|
||||||
|
raise ValueError('Invalid value for warmup')
|
||||||
|
self.warmup = warmup
|
||||||
|
self.final_lr = final_lr
|
||||||
|
super().__init__(optimizer, last_epoch, verbose)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
if not self._get_lr_called_within_step:
|
||||||
|
import warnings
|
||||||
|
warnings.warn("To get the last learning rate computed by the scheduler, "
|
||||||
|
"please use `get_last_lr()`.")
|
||||||
|
|
||||||
|
return self._get_closed_form_lr()
|
||||||
|
|
||||||
|
def _get_closed_form_lr(self):
|
||||||
|
warmup = 1 - self.warmup ** (self.last_epoch + 1)
|
||||||
|
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
|
||||||
|
return [warmup * max(self.final_lr, base_lr * lr_mult)
|
||||||
|
for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
def copy_state_dict(model, state_dict):
|
||||||
|
"""Load state_dict to model, but only for keys that match exactly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): model to load state_dict.
|
||||||
|
state_dict (OrderedDict): state_dict to load.
|
||||||
|
"""
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
for key in state_dict:
|
||||||
|
if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
|
||||||
|
if isinstance(state_dict[key], torch.nn.Parameter):
|
||||||
|
# backwards compatibility for serialized parameters
|
||||||
|
state_dict[key] = state_dict[key].data
|
||||||
|
model_state_dict[key] = state_dict[key]
|
||||||
|
|
||||||
|
model.load_state_dict(model_state_dict, strict=False)
|
||||||
|
|
||||||
|
def create_optimizer_from_config(optimizer_config, parameters):
|
||||||
|
"""Create optimizer from config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameters (iterable): parameters to optimize.
|
||||||
|
optimizer_config (dict): optimizer config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.optim.Optimizer: optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
optimizer_type = optimizer_config["type"]
|
||||||
|
|
||||||
|
if optimizer_type == "FusedAdam":
|
||||||
|
from deepspeed.ops.adam import FusedAdam
|
||||||
|
optimizer = FusedAdam(parameters, **optimizer_config["config"])
|
||||||
|
else:
|
||||||
|
optimizer_fn = getattr(torch.optim, optimizer_type)
|
||||||
|
optimizer = optimizer_fn(parameters, **optimizer_config["config"])
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def create_scheduler_from_config(scheduler_config, optimizer):
|
||||||
|
"""Create scheduler from config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_config (dict): scheduler config.
|
||||||
|
optimizer (torch.optim.Optimizer): optimizer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.optim.lr_scheduler._LRScheduler: scheduler.
|
||||||
|
"""
|
||||||
|
if scheduler_config["type"] == "InverseLR":
|
||||||
|
scheduler_fn = InverseLR
|
||||||
|
else:
|
||||||
|
scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
|
||||||
|
scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
|
||||||
|
return scheduler
|
||||||
Loading…
Reference in a new issue