diff --git a/.gitignore b/.gitignore index 3378883..a380d64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,178 @@ -# /static/videos/*.mp4 -# /static/videos/*.mov \ No newline at end of file +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + + + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +*.ckpt +*.wav +# *.mp4 +*.mp3 +*.jsonl +wandb/* + + + + +model/ +logs/ +log/ +saved_ckpt/ +wandb/ +data/ +demo_result/ +model/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..43b0654 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2023 Stability AI +Copyright (c) 2025 AudioX, HKUST + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 5182de4..717aaa3 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,20 @@ -# AudioX: Diffusion Transformer for Anything-to-Audio Generation +# 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation + +[![arXiv](https://img.shields.io/badge/arXiv-2503.10522-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2503.10522) +[![Project Page](https://img.shields.io/badge/GitHub.io-Project-blue?logo=Github&style=flat-square)](https://zeyuet.github.io/AudioX/) +[![🤗 Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/HKUSTAudio/AudioX) -[![arXiv](https://img.shields.io/badge/arXiv-2503.10522-brightgreen.svg?style=flat-square)](https://arxiv.org/pdf/2503.10522) [![githubio](https://img.shields.io/badge/GitHub.io-Project-blue?logo=Github&style=flat-square)](https://zeyuet.github.io/AudioX/) +--- + +**This is the official repository for "[AudioX: Diffusion Transformer for Anything-to-Audio Generation](https://arxiv.org/pdf/2503.10522)".** -**This is the repository for "AudioX: Diffusion Transformer for Anything-to-Audio Generation".** ## 📺 Demo Video https://github.com/user-attachments/assets/0d8dd927-ff0f-4b35-ab1f-b3c3915017be +--- ## ✨ Abstract @@ -34,8 +40,135 @@ Audio and music generation have emerged as crucial tasks in many applications, y ## Code -To be released. -
+### 🛠️ Environment Setup +```bash +git clone https://github.com/ZeyueT/AudioX.git +cd AudioX +conda create -n AudioX python=3.8.20 +conda activate AudioX +pip install git+https://github.com/ZeyueT/AudioX.git +conda install -c conda-forge ffmpeg libsndfile + +``` + +## 🪄 Pretrained Checkpoints + +Download the pretrained model from 🤗 [AudioX on Hugging Face](https://huggingface.co/HKUSTAudio/AudioX): + +```bash +mkdir -p model +wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/model.ckpt -O model/model.ckpt +wget https://huggingface.co/HKUSTAudio/AudioX/resolve/main/config.json -O model/config.json +``` + +### 🤗 Gradio Demo + +To launch the Gradio demo locally, run: + +```bash +python3 run_gradio.py \ + --model-config model/config.json \ + --share +``` + + +### 🎯 Prompt Configuration Examples + +| Task | `video_path` | `text_prompt` | `audio_path` | +|:---------------------|:-------------------|:----------------------------------------------|:-------------| +| Text-to-Audio (T2A) | `None` | `"Typing on a keyboard"` | `None` | +| Text-to-Music (T2M) | `None` | `"A music with piano and violin"` | `None` | +| Video-to-Audio (V2A) | `"video_path.mp4"` | `"Generate general audio for the video"` | `None` | +| Video-to-Music (V2M) | `"video_path.mp4"` | `"Generate music for the video"` | `None` | +| TV-to-Audio (TV2A) | `"video_path.mp4"` | `"Ocean waves crashing with people laughing"` | `None` | +| TV-to-Music (TV2M) | `"video_path.mp4"` | `"Generate music with piano instrument"` | `None` | + +### 🖥️ Script Inference + +```python +import torch +import torchaudio +from einops import rearrange +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.inference.generation import generate_diffusion_cond +from stable_audio_tools.data.utils import read_video, merge_video_audio +from stable_audio_tools.data.utils import load_and_process_audio +import os + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Download model +model, model_config = get_pretrained_model("HKUSTAudio/AudioX") +sample_rate = model_config["sample_rate"] +sample_size = model_config["sample_size"] +target_fps = model_config["video_fps"] +seconds_start = 0 +seconds_total = 10 + +model = model.to(device) + +# for video-to-music generation +video_path = "example/V2M_sample-1.mp4" +text_prompt = "Generate music for the video" +audio_path = None + +video_tensor = read_video(video_path, seek_time=0, duration=seconds_total, target_fps=target_fps) +audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total) + +conditioning = [{ + "video_prompt": [video_tensor.unsqueeze(0)], + "text_prompt": text_prompt, + "audio_prompt": audio_tensor.unsqueeze(0), + "seconds_start": seconds_start, + "seconds_total": seconds_total +}] + +# Generate stereo audio +output = generate_diffusion_cond( + model, + steps=250, + cfg_scale=7, + conditioning=conditioning, + sample_size=sample_size, + sigma_min=0.3, + sigma_max=500, + sampler_type="dpmpp-3m-sde", + device=device +) + +# Rearrange audio batch to a single sequence +output = rearrange(output, "b d n -> d (b n)") + +# Peak normalize, clip, convert to int16, and save to file +output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() +torchaudio.save("output.wav", output, sample_rate) + +if video_path is not None and os.path.exists(video_path): + merge_video_audio(video_path, "output.wav", "output.mp4", 0, seconds_total) + +``` + + +## 🚀 Citation + +If you find our work useful, please consider citing: + +``` +@article{tian2025audiox, + title={AudioX: Diffusion Transformer for Anything-to-Audio Generation}, + author={Tian, Zeyue and Jin, Yizhu and Liu, Zhaoyang and Yuan, Ruibin and Tan, Xu and Chen, Qifeng and Xue, Wei and Guo, Yike}, + journal={arXiv preprint arXiv:2503.10522}, + year={2025} +} +``` + +## 📭 Contact + +If you have any comments or questions, feel free to contact Zeyue Tian(ztianad@connect.ust.hk). + +## License + +Please follow [MIT License](./LICENSE). \ No newline at end of file diff --git a/defaults.ini b/defaults.ini new file mode 100644 index 0000000..9f240a3 --- /dev/null +++ b/defaults.ini @@ -0,0 +1,56 @@ + +[DEFAULTS] + +#name of the run +name = stable_audio_tools + +# the batch size +batch_size = 8 + +# number of GPUs to use for training +num_gpus = 1 + +# number of nodes to use for training +num_nodes = 1 + +# Multi-GPU strategy for PyTorch Lightning +strategy = "" + +# Precision to use for training +precision = "16-mixed" + +# number of CPU workers for the DataLoader +num_workers = 8 + +# the random seed +seed = 42 + +# Batches for gradient accumulation +accum_batches = 1 + +# Number of steps between checkpoints +checkpoint_every = 10000 + +# trainer checkpoint file to restart training from +ckpt_path = '' + +# model checkpoint file to start a new training run from +pretrained_ckpt_path = '' + +# Checkpoint path for the pretransform model if needed +pretransform_ckpt_path = '' + +# configuration model specifying model hyperparameters +model_config = '' + +# configuration for datasets +dataset_config = '' + +# directory to save the checkpoints in +save_dir = '' + +# gradient_clip_val passed into PyTorch Lightning Trainer +gradient_clip_val = 0.0 + +# remove the weight norm from the pretransform model +remove_pretransform_weight_norm = '' \ No newline at end of file diff --git a/example/V2A_sample-1.mp4 b/example/V2A_sample-1.mp4 new file mode 100644 index 0000000..6440959 Binary files /dev/null and b/example/V2A_sample-1.mp4 differ diff --git a/example/V2A_sample-2.mp4 b/example/V2A_sample-2.mp4 new file mode 100644 index 0000000..4b51342 Binary files /dev/null and b/example/V2A_sample-2.mp4 differ diff --git a/example/V2A_sample-3.mp4 b/example/V2A_sample-3.mp4 new file mode 100644 index 0000000..038d788 Binary files /dev/null and b/example/V2A_sample-3.mp4 differ diff --git a/example/V2M_sample-1.mp4 b/example/V2M_sample-1.mp4 new file mode 100644 index 0000000..700b9aa Binary files /dev/null and b/example/V2M_sample-1.mp4 differ diff --git a/example/V2M_sample-2.mp4 b/example/V2M_sample-2.mp4 new file mode 100644 index 0000000..c7f5049 Binary files /dev/null and b/example/V2M_sample-2.mp4 differ diff --git a/example/V2M_sample-3.mp4 b/example/V2M_sample-3.mp4 new file mode 100644 index 0000000..02e380e Binary files /dev/null and b/example/V2M_sample-3.mp4 differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7fd26b9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/run_gradio.py b/run_gradio.py new file mode 100644 index 0000000..a303231 --- /dev/null +++ b/run_gradio.py @@ -0,0 +1,32 @@ +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.interface.gradio import create_ui +import json + +import torch + +def main(args): + torch.manual_seed(42) + + interface = create_ui( + model_config_path = args.model_config, + ckpt_path=args.ckpt_path, + pretrained_name=args.pretrained_name, + pretransform_ckpt_path=args.pretransform_ckpt_path, + model_half=args.model_half + ) + interface.queue() + interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Run gradio interface') + parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) + parser.add_argument('--model-config', type=str, help='Path to model config', required=False) + parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) + parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False) + parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False) + parser.add_argument('--username', type=str, help='Gradio username', required=False) + parser.add_argument('--password', type=str, help='Gradio password', required=False) + parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..800f09d --- /dev/null +++ b/setup.py @@ -0,0 +1,47 @@ +from setuptools import setup, find_packages + +setup( + name='AudioX', + version='0.1.0', + url='https://github.com/ZeyueT/AudioX.git', + author='AudioX, HKUST', + description='Training and inference tools for generative audio models from AudioX', + packages=find_packages(), + install_requires=[ + 'aeiou', + 'alias-free-torch==0.0.6', + 'auraloss==0.4.0', + 'descript-audio-codec==1.0.0', + 'decord==0.6.0', + 'einops', + 'einops_exts', + 'ema-pytorch==0.2.3', + 'encodec==0.1.1', + 'gradio==4.44.1', + 'gradio_client==1.3.0', + 'huggingface_hub', + 'importlib-resources==5.12.0', + 'k-diffusion==0.1.1', + 'laion-clap==1.1.6', + 'local-attention==1.8.6', + 'pandas==2.0.2', + 'pedalboard==0.9.14', + 'prefigure==0.0.9', + 'pytorch_lightning==2.4.0', + 'PyWavelets==1.4.1', + 'safetensors', + 'sentencepiece==0.1.99', + 'torch==2.4.1', + 'torchaudio==2.4.1', + 'torchmetrics==1.5.2', + 'tqdm', + 'transformers==4.46.2', + 'v-diffusion-pytorch==0.0.2', + 'vector-quantize-pytorch==1.9.14', + 'wandb', + 'webdataset==0.2.48', + 'x-transformers==1.42.11', + 'flash_attn' + ], + +) \ No newline at end of file diff --git a/stable_audio_tools/__init__.py b/stable_audio_tools/__init__.py new file mode 100644 index 0000000..22446be --- /dev/null +++ b/stable_audio_tools/__init__.py @@ -0,0 +1,2 @@ +from .models.factory import create_model_from_config, create_model_from_config_path +from .models.pretrained import get_pretrained_model \ No newline at end of file diff --git a/stable_audio_tools/inference/__init__.py b/stable_audio_tools/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable_audio_tools/inference/generation.py b/stable_audio_tools/inference/generation.py new file mode 100644 index 0000000..e8df9f2 --- /dev/null +++ b/stable_audio_tools/inference/generation.py @@ -0,0 +1,275 @@ +import numpy as np +import torch +import typing as tp +import math +from torchaudio import transforms as T + +from .utils import prepare_audio +from .sampling import sample, sample_k, sample_rf +from ..data.utils import PadCrop + +def generate_diffusion_uncond( + model, + steps: int = 250, + batch_size: int = 1, + sample_size: int = 2097152, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + # seed = 777 + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + + # Inpainting mask + + if init_audio is not None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + # Now the generative AI part: + + diff_objective = model.diffusion_objective + + if diff_objective == "v": + # k-diffusion denoising process go! + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) + elif diff_objective == "rectified_flow": + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) + + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + + +def generate_diffusion_cond( + model, + steps: int = 250, + cfg_scale=6, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + negative_conditioning: dict = None, + negative_conditioning_tensors: tp.Optional[dict] = None, + batch_size: int = 1, + sample_size: int = 2097152, + sample_rate: int = 48000, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + mask_args: dict = None, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + """ + Generate audio from a prompt using a diffusion model. + + Args: + model: The diffusion model to use for generation. + steps: The number of diffusion steps to use. + cfg_scale: Classifier-free guidance scale + conditioning: A dictionary of conditioning parameters to use for generation. + conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. + batch_size: The batch size to use for generation. + sample_size: The length of the audio to generate, in samples. + sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) + seed: The random seed to use for generation, or -1 to use a random seed. + device: The device to use for generation. + init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. + init_noise_level: The noise level to use when generating from an initial audio sample. + return_latents: Whether to return the latents used for generation instead of the decoded audio. + **sampler_kwargs: Additional keyword arguments to pass to the sampler. + """ + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + # seed = 777 + # print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + torch.backends.cudnn.benchmark = False + + # Conditioning + assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" + if conditioning_tensors is None: + conditioning_tensors = model.conditioner(conditioning, device) + conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) + + if negative_conditioning is not None or negative_conditioning_tensors is not None: + + if negative_conditioning_tensors is None: + negative_conditioning_tensors = model.conditioner(negative_conditioning, device) + + negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) + else: + negative_conditioning_tensors = {} + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + mask_args = None + + # Inpainting mask + if init_audio is not None and mask_args is not None: + # Cut and paste init_audio according to cropfrom, pastefrom, pasteto + # This is helpful for forward and reverse outpainting + cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) + pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) + pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) + assert pastefrom < pasteto, "Paste From should be less than Paste To" + croplen = pasteto - pastefrom + if cropfrom + croplen > sample_size: + croplen = sample_size - cropfrom + cropto = cropfrom + croplen + pasteto = pastefrom + croplen + cutpaste = init_audio.new_zeros(init_audio.shape) + cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] + #print(cropfrom, cropto, pastefrom, pasteto) + init_audio = cutpaste + # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args + mask = build_mask(sample_size, mask_args) + mask = mask.to(device) + elif init_audio is not None and mask_args is None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + model_dtype = next(model.model.parameters()).dtype + noise = noise.type(model_dtype) + conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()} + # Now the generative AI part: + # k-diffusion denoising process go! + + diff_objective = model.diffusion_objective + + if diff_objective == "v": + # k-diffusion denoising process go! + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + + elif diff_objective == "rectified_flow": + + if "sigma_min" in sampler_kwargs: + del sampler_kwargs["sigma_min"] + + if "sampler_type" in sampler_kwargs: + del sampler_kwargs["sampler_type"] + + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + + # v-diffusion: + del noise + del conditioning_tensors + del conditioning_inputs + torch.cuda.empty_cache() + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + + if model.pretransform is not None and not return_latents: + #cast sampled latents to pretransform dtype + sampled = sampled.to(next(model.pretransform.parameters()).dtype) + sampled = model.pretransform.decode(sampled) + + return sampled + +# builds a softmask given the parameters +# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, +# and anything between is a mixture of old/new +# ideally 0.5 is half/half mixture but i haven't figured this out yet +def build_mask(sample_size, mask_args): + maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) + maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) + softnessL = round(mask_args["softnessL"]/100.0 * sample_size) + softnessR = round(mask_args["softnessR"]/100.0 * sample_size) + marination = mask_args["marination"] + # use hann windows for softening the transition (i don't know if this is correct) + hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] + hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] + # build the mask. + mask = torch.zeros((sample_size)) + mask[maskstart:maskend] = 1 + mask[maskstart:maskstart+softnessL] = hannL + mask[maskend-softnessR:maskend] = hannR + # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds + if marination > 0: + mask = mask * (1-marination) + return mask diff --git a/stable_audio_tools/inference/sampling.py b/stable_audio_tools/inference/sampling.py new file mode 100644 index 0000000..060dda6 --- /dev/null +++ b/stable_audio_tools/inference/sampling.py @@ -0,0 +1,235 @@ +import torch +import math +from tqdm import trange, tqdm + +import k_diffusion as K + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + +def t_to_alpha_sigma(t): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + +@torch.no_grad() +def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): + """Draws samples from a model given starting noise. Euler method""" + + # Make tensor of ones to broadcast the single t values + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(sigma_max, 0, steps + 1) + + #alphas, sigmas = 1-t, t + + for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): + # Broadcast the current timestep to the correct shape + t_curr_tensor = t_curr * torch.ones( + (x.shape[0],), dtype=x.dtype, device=x.device + ) + dt = t_prev - t_curr # we solve backwards in our formulation + x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) + + # If we are on the last timestep, output the denoised image + return x + +@torch.no_grad() +def sample(model, x, steps, eta, **extra_args): + """Draws samples from a model given starting noise. v-diffusion""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + + alphas, sigmas = get_alphas_sigmas(t) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], **extra_args).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. + if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred + +# Soft mask inpainting is just shrinking hard (binary) mask inpainting +# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step +def get_bmask(i, steps, mask): + strength = (i+1)/(steps) + # convert to binary mask + bmask = torch.where(mask<=strength,1,0) + return bmask + +def make_cond_model_fn(model, cond_fn): + def cond_model_fn(x, sigma, **kwargs): + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) + return cond_denoised + return cond_model_fn + +# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_k( + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + denoiser = K.external.VDenoiser(model_fn) + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has + sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) + # Scale the initial noise by sigma + noise = noise * sigmas[0] + + wrapped_callback = callback + + + if mask is None and init_data is not None: + # VARIATION (no inpainting) + # set the initial latent to the init_data, and noise it with initial sigma + + x = init_data + noise + + elif mask is not None and init_data is not None: + # INPAINTING + bmask = get_bmask(0, steps, mask) + # initial noising + input_noised = init_data + noise + # set the initial latent to a mix of init_data and noise, based on step 0's binary mask + x = input_noised * bmask + noise * (1-bmask) + # define the inpainting callback function (Note: side effects, it mutates x) + # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` + def inpainting_callback(args): + i = args["i"] + x = args["x"] + sigma = args["sigma"] + #denoised = args["denoised"] + # noise the init_data input with this step's appropriate amount of noise + input_noised = init_data + torch.randn_like(init_data) * sigma + # shrinking hard mask + bmask = get_bmask(i, steps, mask) + # mix input_noise with x, using binary mask + new_x = input_noised * bmask + x * (1-bmask) + # mutate x + x[:,:,:] = new_x[:,:,:] + # wrap together the inpainting callback and the user-submitted callback. + if callback is None: + wrapped_callback = inpainting_callback + else: + wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) + else: + # SAMPLING + # set the initial latent to noise + x = noise + # x = noise + + with torch.cuda.amp.autocast(): + if sampler_type == "k-heun": + return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-lms": + return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpmpp-2s-ancestral": + return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-2": + return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-fast": + return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-adaptive": + return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + +# Uses discrete Euler sampling for rectified flow models +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_rf( + model_fn, + noise, + init_data=None, + steps=100, + sigma_max=1, + device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + if sigma_max > 1: + sigma_max = 1 + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + wrapped_callback = callback + + if init_data is not None: + # VARIATION (no inpainting) + # Interpolate the init data and the noise for init audio + x = init_data * (1 - sigma_max) + noise * sigma_max + else: + # SAMPLING + # set the initial latent to noise + x = noise + + with torch.cuda.amp.autocast(): + # TODO: Add callback support + #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) + return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) \ No newline at end of file diff --git a/stable_audio_tools/inference/utils.py b/stable_audio_tools/inference/utils.py new file mode 100644 index 0000000..6a6c0a5 --- /dev/null +++ b/stable_audio_tools/inference/utils.py @@ -0,0 +1,35 @@ +from ..data.utils import PadCrop + +from torchaudio import transforms as T + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/stable_audio_tools/interface/__init__.py b/stable_audio_tools/interface/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py new file mode 100644 index 0000000..57fec31 --- /dev/null +++ b/stable_audio_tools/interface/gradio.py @@ -0,0 +1,495 @@ +import gc +import platform +import os +import subprocess as sp +import gradio as gr +import json +import torch +import torchaudio + +from aeiou.viz import audio_spectrogram_image +from einops import rearrange +from safetensors.torch import load_file +from torch.nn import functional as F +from torchaudio import transforms as T + +from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..models.factory import create_model_from_config +from ..models.pretrained import get_pretrained_model +from ..models.utils import load_ckpt_state_dict +from ..inference.utils import prepare_audio +from ..training.utils import copy_state_dict +from ..data.utils import read_video, merge_video_audio + + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + + +device = torch.device("cpu") + +os.environ['TMPDIR'] = './tmp' + +current_model_name = None +current_model = None +current_sample_rate = None +current_sample_size = None + + + +def load_model(model_name, model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + global model_configurations + + if pretrained_name is not None: + print(f"Loading pretrained model {pretrained_name}") + model, model_config = get_pretrained_model(pretrained_name) + elif model_config is not None and model_ckpt_path is not None: + print(f"Creating model from config") + model = create_model_from_config(model_config) + print(f"Loading model checkpoint from {model_ckpt_path}") + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + if pretransform_ckpt_path is not None: + print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") + model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) + print(f"Done loading pretransform") + model.to(device).eval().requires_grad_(False) + if model_half: + model.to(torch.float16) + print(f"Done loading model") + return model, model_config, sample_rate, sample_size + +def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total): + if audio_path is None: + return torch.zeros((2, int(sample_rate * seconds_total))) + audio_tensor, sr = torchaudio.load(audio_path) + start_index = int(sample_rate * seconds_start) + target_length = int(sample_rate * seconds_total) + end_index = start_index + target_length + audio_tensor = audio_tensor[:, start_index:end_index] + if audio_tensor.shape[1] < target_length: + pad_length = target_length - audio_tensor.shape[1] + audio_tensor = F.pad(audio_tensor, (pad_length, 0)) + return audio_tensor + +def generate_cond( + prompt, + negative_prompt=None, + video_file=None, + video_path=None, + audio_prompt_file=None, + audio_prompt_path=None, + seconds_start=0, + seconds_total=10, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + cfg_rescale=0.0, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1 + ): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + print(f"Prompt: {prompt}") + preview_images = [] + if preview_every == 0: + preview_every = None + + try: + has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() + except Exception: + has_mps = False + if has_mps: + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + model_name = 'default' + cfg = model_configurations[model_name] + model_config_path = cfg.get("model_config") + ckpt_path = cfg.get("ckpt_path") + pretrained_name = cfg.get("pretrained_name") + pretransform_ckpt_path = cfg.get("pretransform_ckpt_path") + model_type = cfg.get("model_type", "diffusion_cond") + if model_config_path: + with open(model_config_path) as f: + model_config = json.load(f) + else: + model_config = None + target_fps = model_config.get("video_fps", 5) + global current_model_name, current_model, current_sample_rate, current_sample_size + if current_model is None or model_name != current_model_name: + current_model, model_config, sample_rate, sample_size = load_model( + model_name=model_name, + model_config=model_config, + model_ckpt_path=ckpt_path, + pretrained_name=pretrained_name, + pretransform_ckpt_path=pretransform_ckpt_path, + device=device, + model_half=False + ) + current_model_name = model_name + model = current_model + current_sample_rate = sample_rate + current_sample_size = sample_size + else: + model = current_model + sample_rate = current_sample_rate + sample_size = current_sample_size + if video_file is not None: + video_path = video_file.name + elif video_path: + video_path = video_path.strip() + else: + video_path = None + + if audio_prompt_file is not None: + print(f'audio_prompt_file: {audio_prompt_file}') + audio_path = audio_prompt_file.name + elif audio_prompt_path: + audio_path = audio_prompt_path.strip() + else: + audio_path = None + + Video_tensors = read_video(video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps) + audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total) + + audio_tensor = audio_tensor.to(device) + seconds_input = sample_size / sample_rate + print(f'video_path: {video_path}') + + if not prompt: + prompt = "" + + conditioning = [{ + "video_prompt": [Video_tensors.unsqueeze(0)], + "text_prompt": prompt, + "audio_prompt": audio_tensor.unsqueeze(0), + "seconds_start": seconds_start, + "seconds_total": seconds_input + }] * batch_size + if negative_prompt: + negative_conditioning = [{ + "video_prompt": [Video_tensors.unsqueeze(0)], + "text_prompt": negative_prompt, + "audio_prompt": audio_tensor.unsqueeze(0), + "seconds_start": seconds_start, + "seconds_total": seconds_total + }] * batch_size + else: + negative_conditioning = None + try: + device = next(model.parameters()).device + except Exception as e: + device = next(current_model.parameters()).device + seed = int(seed) + if not use_init: + init_audio = None + input_sample_size = sample_size + if init_audio is not None: + in_sr, init_audio = init_audio + init_audio = torch.from_numpy(init_audio).float().div(32767) + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + audio_length = init_audio.shape[-1] + if audio_length > sample_size: + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + init_audio = (sample_rate, init_audio) + def progress_callback(callback_info): + nonlocal preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + if (current_step - 1) % preview_every == 0: + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + denoised = rearrange(denoised, "b d n -> d (b n)") + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + if mask_cropfrom is not None: + mask_args = { + "cropfrom": mask_cropfrom, + "pastefrom": mask_pastefrom, + "pasteto": mask_pasteto, + "maskstart": mask_maskstart, + "maskend": mask_maskend, + "softnessL": mask_softnessL, + "softnessR": mask_softnessR, + "marination": mask_marination, + } + else: + mask_args = None + if model_type == "diffusion_cond": + audio = generate_diffusion_cond( + model, + conditioning=conditioning, + negative_conditioning=negative_conditioning, + steps=steps, + cfg_scale=cfg_scale, + batch_size=batch_size, + sample_size=input_sample_size, + sample_rate=sample_rate, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + mask_args=mask_args, + callback=progress_callback if preview_every is not None else None, + scale_phi=cfg_rescale + ) + elif model_type == "diffusion_uncond": + audio = generate_diffusion_uncond( + model, + steps=steps, + batch_size=batch_size, + sample_size=input_sample_size, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + callback=progress_callback if preview_every is not None else None + ) + else: + raise ValueError(f"Unsupported model type: {model_type}") + audio = rearrange(audio, "b d n -> d (b n)") + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + file_name = os.path.basename(video_path) if video_path else "output" + output_dir = f"demo_result" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_video_path = f"{output_dir}/{file_name}" + torchaudio.save(f"{output_dir}/output.wav", audio, sample_rate) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + if video_path: + merge_video_audio(video_path, f"{output_dir}/output.wav", output_video_path, seconds_start, seconds_total) + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + del video_path + torch.cuda.empty_cache() + gc.collect() + return (output_video_path, f"{output_dir}/output.wav") + +def toggle_custom_model(selected_model): + return gr.Row.update(visible=(selected_model == "Custom Model")) + +def create_sampling_ui(model_config_map, inpainting=False): + with gr.Blocks() as demo: + gr.Markdown( + """ + # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation + **[Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/Zeyue7/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)** + """ + ) + + with gr.Tab("Generation"): + + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt") + negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt", visible=False) + video_path = gr.Textbox(label="Video Path", placeholder="Enter video file path") + video_file = gr.File(label="Upload Video File") + audio_prompt_file = gr.File(label="Upload Audio Prompt File", visible=False) + audio_prompt_path = gr.Textbox(label="Audio Prompt Path", placeholder="Enter audio file path", visible=False) + with gr.Row(): + with gr.Column(scale=6): + with gr.Accordion("Video Params", open=False): + seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Video Seconds Start") + seconds_total_slider = gr.Slider(minimum=0, maximum=10, step=1, value=10, label="Seconds Total", interactive=False) + with gr.Row(): + with gr.Column(scale=4): + with gr.Accordion("Sampler Params", open=False): + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") + cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale") + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + sampler_type_dropdown = gr.Dropdown( + ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], + label="Sampler Type", + value="dpmpp-3m-sde" + ) + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max") + cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG Rescale Amount") + with gr.Row(): + with gr.Column(scale=4): + with gr.Accordion("Init Audio", open=False, visible=False): + init_audio_checkbox = gr.Checkbox(label="Use Init Audio") + init_audio_input = gr.Audio(label="Init Audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init Noise Level") + gr.Markdown("## Examples") + with gr.Accordion("Click to show examples", open=False): + with gr.Row(): + gr.Markdown("**📝 Task: Text-to-Audio**") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *Typing on a keyboard*") + ex1 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *Ocean waves crashing*") + ex2 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *Footsteps in snow*") + ex3 = gr.Button("Load Example") + with gr.Row(): + gr.Markdown("**🎶 Task: Text-to-Music**") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*") + ex4 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*") + ex5 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*") + ex6 = gr.Button("Load Example") + with gr.Row(): + gr.Markdown("**🎬 Task: Video-to-Audio**\nPrompt: *Generate general audio for the video*") + with gr.Column(scale=1.2): + gr.Video("example/V2A_sample-1.mp4") + ex7 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Video("example/V2A_sample-2.mp4") + ex8 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Video("example/V2A_sample-3.mp4") + ex9 = gr.Button("Load Example") + with gr.Row(): + gr.Markdown("**🎵 Task: Video-to-Music**\nPrompt: *Generate music for the video*") + with gr.Column(scale=1.2): + gr.Video("example/V2M_sample-1.mp4") + ex10 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Video("example/V2M_sample-2.mp4") + ex11 = gr.Button("Load Example") + with gr.Column(scale=1.2): + gr.Video("example/V2M_sample-3.mp4") + ex12 = gr.Button("Load Example") + with gr.Row(): + generate_button = gr.Button("Generate", variant='primary', scale=1) + with gr.Row(): + with gr.Column(scale=6): + video_output = gr.Video(label="Output Video", interactive=False) + audio_output = gr.Audio(label="Output Audio", interactive=False) + send_to_init_button = gr.Button("Send to Init Audio", scale=1, visible=False) + send_to_init_button.click( + fn=lambda audio: audio, + inputs=[audio_output], + outputs=[init_audio_input] + ) + inputs = [ + prompt, + negative_prompt, + video_file, + video_path, + audio_prompt_file, + audio_prompt_path, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider + ] + generate_button.click( + fn=generate_cond, + inputs=inputs, + outputs=[ + video_output, + audio_output + ], + api_name="generate" + ) + ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex3.click(lambda: ["Footsteps in snow", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex7.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3737819478", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex8.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "1900718499", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex9.click(lambda: ["Generate general audio for the video", None, None, "example/V2A_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "2289822202", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex10.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-1.mp4", None, None, 0, 10, 7.0, 100, 0, "3498087420", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex11.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-2.mp4", None, None, 0, 10, 7.0, 100, 0, "3753837734", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + ex12.click(lambda: ["Generate music for the video", None, None, "example/V2M_sample-3.mp4", None, None, 0, 10, 7.0, 100, 0, "3510832996", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) + return demo + +def create_txt2audio_ui(model_config_map): + with gr.Blocks(css=".gradio-container { max-width: 1120px; margin: auto; }") as ui: + with gr.Tab("Generation"): + create_sampling_ui(model_config_map) + return ui + +def toggle_custom_model(selected_model): + return gr.Row.update(visible=(selected_model == "Custom Model")) + +def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): + global model_configurations + global device + + try: + has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() + except Exception: + has_mps = False + + if has_mps: + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + print("Using device:", device) + + model_configurations = { + "default": { + "model_config": "./model/config.json", + "ckpt_path": "./model/model.ckpt" + } + } + ui = create_txt2audio_ui(model_configurations) + return ui + +if __name__ == "__main__": + ui = create_ui( + model_config_path='./model/config.json', + share=True + ) + ui.launch() diff --git a/stable_audio_tools/models/__init__.py b/stable_audio_tools/models/__init__.py new file mode 100644 index 0000000..7e27bbc --- /dev/null +++ b/stable_audio_tools/models/__init__.py @@ -0,0 +1 @@ +from .factory import create_model_from_config, create_model_from_config_path \ No newline at end of file diff --git a/stable_audio_tools/models/adp.py b/stable_audio_tools/models/adp.py new file mode 100644 index 0000000..49eb526 --- /dev/null +++ b/stable_audio_tools/models/adp.py @@ -0,0 +1,1588 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d + +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + +""" Conditioning Modules """ + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + +def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... -> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/stable_audio_tools/models/autoencoders.py b/stable_audio_tools/models/autoencoders.py new file mode 100644 index 0000000..7c4bdbd --- /dev/null +++ b/stable_audio_tools/models/autoencoders.py @@ -0,0 +1,794 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +from ..inference.sampling import sample +from ..inference.utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + latents = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. + audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. + Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +class DiffusionAutoencoder(AudioAutoencoder): + def __init__( + self, + diffusion: ConditionedDiffusionModel, + diffusion_downsampling_ratio, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.diffusion = diffusion + + self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio + + if self.encoder is not None: + # Shrink the initial encoder parameters to avoid saturated latents + with torch.no_grad(): + for param in self.encoder.parameters(): + param *= 0.5 + + def decode(self, latents, steps=100): + + upsampled_length = latents.shape[2] * self.downsampling_ratio + + if self.bottleneck is not None: + latents = self.bottleneck.decode(latents) + + if self.decoder is not None: + latents = self.decode(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != upsampled_length: + latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + + noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) + decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + + return decoded + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) + +def create_diffAE_from_config(config: Dict[str, Any]): + + diffae_config = config["model"] + + if "encoder" in diffae_config: + encoder = create_encoder_from_config(diffae_config["encoder"]) + else: + encoder = None + + if "decoder" in diffae_config: + decoder = create_decoder_from_config(diffae_config["decoder"]) + else: + decoder = None + + diffusion_model_type = diffae_config["diffusion"]["type"] + + if diffusion_model_type == "DAU1d": + diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "adp_1d": + diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "dit": + diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) + + latent_dim = diffae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = diffae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = diffae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + bottleneck = diffae_config.get("bottleneck", None) + + pretransform = diffae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + diffusion_downsampling_ratio = None, + + if diffusion_model_type == "DAU1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) + elif diffusion_model_type == "adp_1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) + elif diffusion_model_type == "dit": + diffusion_downsampling_ratio = 1 + + return DiffusionAutoencoder( + encoder=encoder, + decoder=decoder, + diffusion=diffusion, + io_channels=io_channels, + sample_rate=sample_rate, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + diffusion_downsampling_ratio=diffusion_downsampling_ratio, + bottleneck=bottleneck, + pretransform=pretransform + ) diff --git a/stable_audio_tools/models/blocks.py b/stable_audio_tools/models/blocks.py new file mode 100644 index 0000000..3c827fd --- /dev/null +++ b/stable_audio_tools/models/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/stable_audio_tools/models/bottleneck.py b/stable_audio_tools/models/bottleneck.py new file mode 100644 index 0000000..5e81cab --- /dev/null +++ b/stable_audio_tools/models/bottleneck.py @@ -0,0 +1,355 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + self.bypass_mmd = bypass_mmd + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + if self.bypass_mmd: + mmd = torch.tensor(0.0) + else: + mmd = compute_mmd(x) + + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, noise_augment_dim=0, **kwargs): + super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") + + self.noise_augment_dim = noise_augment_dim + + self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) + + def encode(self, x, return_info=False): + info = {} + + orig_dtype = x.dtype + x = x.float() + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + x = x.to(orig_dtype) + + # Reorder indices to match the expected format + indices = rearrange(indices, "b n q -> b q n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/stable_audio_tools/models/codebook_patterns.py b/stable_audio_tools/models/codebook_patterns.py new file mode 100644 index 0000000..f9bd2a9 --- /dev/null +++ b/stable_audio_tools/models/codebook_patterns.py @@ -0,0 +1,545 @@ +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def starts_with_special_token(self): + return self.layout[0] == [] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. + """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output and self.starts_with_special_token(): + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + omit_special_token = self.empty_initial < 0 + out: PatternLayout = [] if omit_special_token else [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, empty_initial: int = 0): + super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. + delays (list of int, optional): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + # the PatternLayout is built as a tuple of sequence position and list of coordinates + # so that it can be reordered properly given the required delay between codebooks of given timesteps + indexed_out: list = [(-1, [])] + max_timesteps = timesteps + self.max_delay + for t in range(max_timesteps): + # for each timestep, we unroll the flattened codebooks, + # emitting the sequence step with the corresponding delay + for step in range(self._num_inner_steps): + if step in self._flattened_codebooks: + # we have codebooks at this virtual step to emit + step_codebooks = self._flattened_codebooks[step] + t_for_q = t + step_codebooks.delay + coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] + if t_for_q < max_timesteps and t < max_timesteps: + indexed_out.append((t_for_q, coords)) + else: + # there is no codebook in this virtual step so we emit an empty list + indexed_out.append((t, [])) + out = [coords for _, coords in sorted(indexed_out)] + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. + + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if delays is None: + delays = [0] * (n_q - 1) + self.delays = delays + assert len(self.delays) == self.n_q - 1 + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for t in range(timesteps): + out.append([LayoutCoord(t, 0)]) + max_delay = max(self.delays) + for t in range(timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= 0: + v.append(LayoutCoord(t_for_q, q + 1)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class MusicLMPattern(CodebooksPatternProvider): + """Almost MusicLM style pattern. This is equivalent to full flattening + but in a different order. + + Args: + n_q (int): Number of codebooks. + group_by (int): Number of codebooks to group together. + """ + def __init__(self, n_q: int, group_by: int = 2): + super().__init__(n_q) + self.group_by = group_by + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for offset in range(0, self.n_q, self.group_by): + for t in range(timesteps): + for q in range(offset, offset + self.group_by): + out.append([LayoutCoord(t, q)]) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) \ No newline at end of file diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py new file mode 100644 index 0000000..916cd94 --- /dev/null +++ b/stable_audio_tools/models/conditioners.py @@ -0,0 +1,710 @@ +#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py + +import torch +import logging, warnings +import string +import typing as tp +import gc + +from .adp import NumberEmbedder +from ..inference.utils import set_audio_channels +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from .utils import load_ckpt_state_dict + +from torch import nn +from transformers import AutoProcessor, CLIPVisionModelWithProjection +import einops +from .temptransformer import SA_Transformer +from torchvision import transforms +import torch +import einops +import torchvision.transforms as transforms + + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + +class IntConditioner(Conditioner): + def __init__(self, + output_dim: int, + min_val: int=0, + max_val: int=512 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) + + def forward(self, ints: tp.List[int], device=None) -> tp.Any: + + #self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats: tp.List[float], device=None) -> tp.Any: + + # Cast the inputs to floats + floats = [float(x) for x in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] + +class CLAPTextConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + use_text_features = False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False): + super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) + + self.use_text_features = use_text_features + self.feature_layer_ix = feature_layer_ix + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.text_branch.requires_grad_(True) + self.model.model.text_branch.train() + else: + self.model.model.text_branch.requires_grad_(False) + self.model.model.text_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.audio_branch + + gc.collect() + torch.cuda.empty_cache() + + def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): + prompt_tokens = self.model.tokenizer(prompts) + attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) + prompt_features = self.model.model.text_branch( + input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), + attention_mask=attention_mask, + output_hidden_states=True + )["hidden_states"][layer_ix] + + return prompt_features, attention_mask + + def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + + if self.use_text_features: + if len(texts) == 1: + text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features = text_features[:1, ...] + text_attention_mask = text_attention_mask[:1, ...] + else: + text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + return [self.proj_out(text_features), text_attention_mask] + + # Fix for CLAP bug when only one text is passed + if len(texts) == 1: + text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] + else: + text_embedding = self.model.get_text_embedding(texts, use_tensor=True) + + text_embedding = text_embedding.unsqueeze(1).to(device) + + return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + +class CLAPAudioConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False): + super().__init__(512, output_dim, project_out=project_out) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.audio_branch.requires_grad_(True) + self.model.model.audio_branch.train() + else: + self.model.model.audio_branch.requires_grad_(False) + self.model.model.audio_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.text_branch + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + + self.model.to(device) + + if isinstance(audios, list) or isinstance(audios, tuple): + audios = torch.cat(audios, dim=0) + + # Convert to mono + mono_audios = audios.mean(dim=1) + + with torch.cuda.amp.autocast(enabled=False): + audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) + + audio_embedding = audio_embedding.unsqueeze(1).to(device) + + return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + + +class CLIPConditioner(Conditioner): + CLIP_MODELS = ["clip-vit-base-patch32"] + + def __init__( + self, + output_dim: int, + clip_model_name: str = "clip-vit-base-patch32", + video_fps: int = 5, + out_features: str = 128, + enable_grad: bool = False, + in_features: int = 5000, + project_out: bool = False, + ): + assert clip_model_name in self.CLIP_MODELS, f"Unknown clip model name: {clip_model_name}" + super().__init__(dim = 768, output_dim=output_dim, project_out=project_out) + + sa_depth=4 + num_heads=16 + dim_head=64 + hidden_scale=4 + duration = 10 + + self.clip_model_name=clip_model_name + + if self.clip_model_name=='clip-vit-base-patch32': + out_features = 128 + temporal_dim=768 + + self.empty_visual_feat = nn.Parameter(torch.zeros(1, out_features, temporal_dim), requires_grad=True) + nn.init.constant_(self.empty_visual_feat, 0) + + in_features = 50*video_fps*duration + + self.visual_encoder_model = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-base-patch32') + self.proj = nn.Linear(in_features=in_features, out_features=out_features) + + self.in_features = in_features + self.out_features = out_features + + self.Temp_transformer = SA_Transformer(temporal_dim, sa_depth, num_heads, dim_head, temporal_dim*hidden_scale, 0.) + self.Temp_pos_embedding = nn.Parameter(torch.randn(1, duration*video_fps, temporal_dim)) + + clip_mean = [0.48145466, 0.4578275, 0.40821073] + clip_std = [0.26862954, 0.26130258, 0.27577711] + self.preprocess_CLIP = transforms.Compose([ + transforms.Normalize(mean=clip_mean, std=clip_std) + ]) + + def process_video_with_custom_preprocessing(self, video_tensor): + video_tensor = video_tensor / 255.0 + video_tensor = self.preprocess_CLIP(video_tensor) + return video_tensor + + def init_first_from_ckpt(self, path): + model = torch.load(path, map_location="cpu") + if "state_dict" in list(model.keys()): + model = model["state_dict"] + # Remove: module prefix + new_model = {} + for key in model.keys(): + new_key = key.replace("module.","") + new_model[new_key] = model[key] + missing, unexpected = self.visual_encoder_model.load_state_dict(new_model, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def forward(self, Video_tensors: tp.List[torch.Tensor], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + visual_encoder_model = self.visual_encoder_model.eval().to(device) + proj = self.proj.to(device) + + original_videos = torch.cat(Video_tensors, dim=0).to(device) + batch_size, time_length, _, _, _ = original_videos.size() + is_zero = torch.all(original_videos == 0, dim=(1,2,3,4)) + Video_tensors = original_videos + Video_tensors = einops.rearrange(Video_tensors, 'b t c h w -> (b t) c h w') + + video_cond_pixel_values = self.process_video_with_custom_preprocessing(video_tensor=Video_tensors.to(device)).to(device) + if self.clip_model_name=='clip-vit-base-patch32': + with torch.no_grad(): + outputs = visual_encoder_model(pixel_values=video_cond_pixel_values) + video_hidden = outputs.last_hidden_state + + video_hidden = einops.rearrange(video_hidden, '(b t) q h -> (b q) t h',b=batch_size,t=time_length) + video_hidden += self.Temp_pos_embedding + video_hidden = self.Temp_transformer(video_hidden) + video_hidden = einops.rearrange(video_hidden, '(b q) t h -> b (t q) h',b=batch_size,t=time_length) + + video_hidden = proj(video_hidden.view(-1, self.in_features)) + video_hidden = video_hidden.view(batch_size, self.out_features, -1) + + empty_visual_feat = self.empty_visual_feat.expand(batch_size, -1, -1) + is_zero_expanded = is_zero.view(batch_size, 1, 1) + video_hidden = torch.where(is_zero_expanded, empty_visual_feat, video_hidden) + + return video_hidden, torch.ones(video_hidden.shape[0], 1).to(device) + + + +class T5Conditioner(Conditioner): + + T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl"] + + T5_MODEL_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "t5-xl": 2048, + "t5-xxl": 4096, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 128, + enable_grad: bool = False, + project_out: bool = False, + ): + assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" + super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + + self.max_length = max_length + self.enable_grad = enable_grad + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) + model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + + embeddings = self.proj_out(embeddings.float()) + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PhonemeConditioner(Conditioner): + """ + A conditioner that turns text into phonemes and embeds them using a lookup table + Only works for English text + + Args: + output_dim: the dimension of the output embeddings + max_length: the maximum number of phonemes to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from g2p_en import G2p + self.max_length = max_length + self.g2p = G2p() + # Reserving 0 for padding, 1 for ignored + self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.phoneme_embedder.to(device) + self.proj_out.to(device) + + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + phoneme_ignore = [" ", *string.punctuation] + # Remove ignored phonemes and cut to max length + batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] + + # Convert to ids + phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + + #Pad to match longest and make a mask tensor for the padding + longest = max([len(ids) for ids in phoneme_ids]) + phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] + phoneme_ids = torch.tensor(phoneme_ids).to(device) + + # Convert to embeddings + phoneme_embeds = self.phoneme_embedder(phoneme_ids) + phoneme_embeds = self.proj_out(phoneme_embeds) + + return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) + + + +class TokenizerLUTConditioner(Conditioner): + """ + A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary + + Args: + tokenizer_name: the name of the tokenizer from the Hugging Face transformers library + output_dim: the dimension of the output embeddings + max_length: the maximum length of the text to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from transformers import AutoTokenizer + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + finally: + logging.disable(previous_level) + + self.max_length = max_length + + self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + embeddings = self.token_embedder(input_ids) + + embeddings = self.proj_out(embeddings) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PretransformConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + audio = torch.cat(audio, dim=0) + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + latents = self.proj_out(latents) + + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + + +class AudioAutoencoderConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + self.empty_audio_feat = nn.Parameter(torch.zeros(1, 215, self.proj_out.out_features), requires_grad=True) + nn.init.constant_(self.empty_audio_feat, 0) + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + original_audios = torch.cat(audio, dim=0).to(device) + is_zero = torch.all(original_audios == 0, dim=(1,2)) + audio = original_audios + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + latents = latents.permute(0, 2, 1) + latents = self.proj_out(latents) + + empty_audio_feat = self.empty_audio_feat.expand(latents.shape[0], -1, -1) + is_zero_expanded = is_zero.view(latents.shape[0], 1, 1) + latents = torch.where(is_zero_expanded, empty_audio_feat, latents) + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + + +class MultiConditioner(nn.Module): + """ + A module that applies multiple conditioners to an input dictionary based on the keys + + Args: + conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") + default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"}) + """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): + super().__init__() + + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys + + def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + output = {} + + for key, conditioner in self.conditioners.items(): + condition_key = key + + conditioner_inputs = [] + + for x in batch_metadata: + + if condition_key not in x: + if condition_key in self.default_keys: + condition_key = self.default_keys[condition_key] + else: + raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") + + if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + conditioner_input = x[condition_key][0] + + else: + conditioner_input = x[condition_key] + + conditioner_inputs.append(conditioner_input) + + output[key] = conditioner(conditioner_inputs, device) + + return output + +def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: + """ + Create a MultiConditioner from a conditioning config dictionary + + Args: + config: the conditioning config dictionary + device: the device to put the conditioners on + """ + conditioners = {} + cond_dim = config["cond_dim"] + + default_keys = config.get("default_keys", {}) + + for conditioner_info in config["configs"]: + id = conditioner_info["id"] + + conditioner_type = conditioner_info["type"] + + conditioner_config = {"output_dim": cond_dim} + + conditioner_config.update(conditioner_info["config"]) + + if conditioner_type == "t5": + conditioners[id] = T5Conditioner(**conditioner_config) + elif conditioner_type == "clip": + conditioners[id] = CLIPConditioner(**conditioner_config) + elif conditioner_type == "clap_text": + conditioners[id] = CLAPTextConditioner(**conditioner_config) + elif conditioner_type == "clap_audio": + conditioners[id] = CLAPAudioConditioner(**conditioner_config) + elif conditioner_type == "int": + conditioners[id] = IntConditioner(**conditioner_config) + elif conditioner_type == "number": + conditioners[id] = NumberConditioner(**conditioner_config) + elif conditioner_type == "phoneme": + conditioners[id] = PhonemeConditioner(**conditioner_config) + elif conditioner_type == "lut": + conditioners[id] = TokenizerLUTConditioner(**conditioner_config) + elif conditioner_type == "pretransform": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) + + elif conditioner_type == "audio_autoencoder": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = AudioAutoencoderConditioner(pretransform, **conditioner_config) + else: + raise ValueError(f"Unknown conditioner type: {conditioner_type}") + + return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/stable_audio_tools/models/diffusion.py b/stable_audio_tools/models/diffusion.py new file mode 100644 index 0000000..00c8c1d --- /dev/null +++ b/stable_audio_tools/models/diffusion.py @@ -0,0 +1,704 @@ +import torch +from torch import nn +from torch.nn import functional as F +from functools import partial +import numpy as np +import typing as tp + +from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .dit import DiffusionTransformer +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..inference.generation import generate_diffusion_cond + +from .adp import UNetCFG1d, UNet1d + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionModel(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, t, **kwargs): + raise NotImplementedError() + +class DiffusionModelWrapper(nn.Module): + def __init__( + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, + ): + super().__init__() + self.io_channels = io_channels + self.sample_size = sample_size + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.model = model + + if pretransform is not None: + self.pretransform = pretransform + else: + self.pretransform = None + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class ConditionedDiffusionModel(nn.Module): + def __init__(self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + self.supports_cross_attention = supports_cross_attention + self.supports_input_concat = supports_input_concat + self.supports_global_cond = supports_global_cond + self.supports_prepend_cond = supports_prepend_cond + + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs): + raise NotImplementedError() + +class ConditionedDiffusionModelWrapper(nn.Module): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids # ['prompt', 'seconds_start', 'seconds_total'] + self.global_cond_ids = global_cond_ids # ['seconds_start', 'seconds_total'] + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.min_input_length = min_input_length + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[torch.Tensor, tp.Any], negative=False): + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = [] + cross_attention_masks = [] + + for key in self.cross_attn_cond_ids: + cross_attn_in, cross_attn_mask = conditioning_tensors[key] + + # Add sequence dimension if it's not there + if len(cross_attn_in.shape) == 2: + cross_attn_in = cross_attn_in.unsqueeze(1) + cross_attn_mask = cross_attn_mask.unsqueeze(1) + + cross_attention_input.append(cross_attn_in) + cross_attention_masks.append(cross_attn_mask) + + cross_attention_input = torch.cat(cross_attention_input, dim=1) # [1, 130, 768] (text feature:128) + cross_attention_masks = torch.cat(cross_attention_masks, dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_conds = [] + for key in self.global_cond_ids: + + global_cond_input = conditioning_tensors[key][0] + + global_conds.append(global_cond_input) + + # Concatenate over the channel dimension + global_cond = torch.cat(global_conds, dim=-1) + + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if len(self.input_concat_ids) > 0: # False + # Concatenate all input concat conditioning inputs over the channel dimension + # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) + input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: # False + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_conds = [] + prepend_cond_masks = [] + + for key in self.prepend_cond_ids: + prepend_cond_input, prepend_cond_mask = conditioning_tensors[key] + prepend_conds.append(prepend_cond_input) + prepend_cond_masks.append(prepend_cond_mask) + + prepend_cond = torch.cat(prepend_conds, dim=1) + prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1) + + if negative: # False + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_cross_attn_mask": cross_attention_masks, + "negative_global_cond": global_cond, + "negative_input_concat_cond": input_concat_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "cross_attn_mask": cross_attention_masks, + "global_cond": global_cond, + "input_concat_cond": input_concat_cond, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class UNetCFG1DWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) + + self.model = UNetCFG1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + p = Profiler() + + p.tick("start") + + channels_list = None + if input_concat_cond is not None: + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + embedding=cross_attn_cond, + embedding_mask=cross_attn_mask, + features=global_cond, + channels_list=channels_list, + embedding_scale=cfg_scale, + embedding_mask_proba=cfg_dropout_prob, + batch_cfg=batch_cfg, + rescale_cfg=rescale_cfg, + negative_embedding=negative_cross_attn_cond, + negative_embedding_mask=negative_cross_attn_mask, + **kwargs) + + p.tick("UNetCFG1D forward") + + #print(f"Profiler: {p}") + return outputs + +class UNet1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) + + self.model = UNet1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs): + + channels_list = None + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + features=global_cond, + channels_list=channels_list, + **kwargs) + + return outputs + +class UNet1DUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = UNet1d(in_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class DAU1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) + + self.model = DiffusionAttnUnet1D(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs): + + return self.model(x, t, cond = input_concat_cond) + +class DiffusionAttnUnet1D(nn.Module): + def __init__( + self, + io_channels = 2, + depth=14, + n_attn_layers = 6, + channels = [128, 128, 256, 256] + [512] * 10, + cond_dim = 0, + cond_noise_aug = False, + kernel_size = 5, + learned_resample = False, + strides = [2] * 13, + conv_bias = True, + use_snake = False + ): + super().__init__() + + self.cond_noise_aug = cond_noise_aug + + self.io_channels = io_channels + + if self.cond_noise_aug: + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_embed = FourierFeatures(1, 16) + + attn_layer = depth - n_attn_layers + + strides = [1] + strides + + block = nn.Identity() + + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + + for i in range(depth, 0, -1): + c = channels[i - 1] + stride = strides[i-1] + if stride > 2 and not learned_resample: + raise ValueError("Must have stride 2 without learned resampling") + + if i > 1: + c_prev = channels[i - 2] + add_attn = i >= attn_layer and n_attn_layers > 0 + block = SkipBlock( + Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + conv_block(c_prev, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + block, + conv_block(c * 2 if i != depth else c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c_prev), + SelfAttention1d(c_prev, c_prev // + 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + ) + else: + cond_embed_dim = 16 if not self.cond_noise_aug else 32 + block = nn.Sequential( + conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), + conv_block(c, c, c), + conv_block(c, c, c), + block, + conv_block(c * 2, c, c), + conv_block(c, c, c), + conv_block(c, c, io_channels, is_last=True), + ) + self.net = block + + with torch.no_grad(): + for param in self.net.parameters(): + param *= 0.5 + + def forward(self, x, t, cond=None, cond_aug_scale=None): + + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) + + inputs = [x, timestep_embed] + + if cond is not None: + if cond.shape[2] != x.shape[2]: + cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + + if self.cond_noise_aug: + # Get a random number between 0 and 1, uniformly sampled + if cond_aug_scale is None: + aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) + else: + aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) + + # Add noise to the conditioning signal + cond = cond + torch.randn_like(cond) * aug_level[:, None, None] + + # Get embedding for noise cond level, reusing timestamp_embed + aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) + + inputs.append(aug_level_embed) + + inputs.append(cond) + + outputs = self.net(torch.cat(inputs, dim=1)) + + return outputs + +class DiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = DiffusionTransformer(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_mask, + negative_cross_attn_cond=negative_cross_attn_cond, + negative_cross_attn_mask=negative_cross_attn_mask, + input_concat_cond=input_concat_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + global_embed=global_cond, + **kwargs) + +class DiTUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + model_type = diffusion_uncond_config.get('type', None) + + diffusion_config = diffusion_uncond_config.get('config', {}) + + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **diffusion_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + **diffusion_config + ) + + elif model_type == "dit": + model = DiTUncondWrapper( + **diffusion_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): + + model_config = config["model"] + + model_type = config["model_type"] + + diffusion_config = model_config.get('diffusion', None) + assert diffusion_config is not None, "Must specify diffusion config" + + diffusion_model_type = diffusion_config.get('type', None) + assert diffusion_model_type is not None, "Must specify diffusion model type" + + diffusion_model_config = diffusion_config.get('config', None) + if diffusion_model_config.get('video_fps', None) is not None: + diffusion_model_config.pop('video_fps') + assert diffusion_model_config is not None, "Must specify diffusion model config" + + if diffusion_model_type == 'adp_cfg_1d': + diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) + elif diffusion_model_type == 'adp_1d': + diffusion_model = UNet1DCondWrapper(**diffusion_model_config) + elif diffusion_model_type == 'dit': + diffusion_model = DiTWrapper(**diffusion_model_config) + + io_channels = model_config.get('io_channels', None) + assert io_channels is not None, "Must specify io_channels in model config" + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + diffusion_objective = diffusion_config.get('diffusion_objective', 'v') + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + + pretransform = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": + min_input_length *= np.prod(diffusion_model_config["factors"]) + elif diffusion_model_type == "dit": + min_input_length *= diffusion_model.model.patch_size + + # Get the proper wrapper class + + extra_kwargs = {} + + if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint": + wrapper_fn = ConditionedDiffusionModelWrapper + + extra_kwargs["diffusion_objective"] = diffusion_objective + + elif model_type == "diffusion_prior": + prior_type = model_config.get("prior_type", None) + assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + + if prior_type == "mono_stereo": + from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior + + return wrapper_fn( + diffusion_model, + conditioner, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + pretransform=pretransform, + io_channels=io_channels, + **extra_kwargs + ) \ No newline at end of file diff --git a/stable_audio_tools/models/discriminators.py b/stable_audio_tools/models/discriminators.py new file mode 100644 index 0000000..b593168 --- /dev/null +++ b/stable_audio_tools/models/discriminators.py @@ -0,0 +1,546 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from functools import reduce +import typing as tp +from einops import rearrange +from audiotools import AudioSignal, STFTParams +from dac.model.discriminator import WNConv1d, WNConv2d + +def get_hinge_losses(score_real, score_fake): + gen_loss = -score_fake.mean() + dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() + return dis_loss, gen_loss + +class EncodecDiscriminator(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + from encodec.msstftd import MultiScaleSTFTDiscriminator + + self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) + + def forward(self, x): + logits, features = self.discriminators(x) + return logits, features + + def loss(self, x, y): + feature_matching_distance = 0. + logits_true, feature_true = self.forward(x) + logits_fake, feature_fake = self.forward(y) + + dis_loss = torch.tensor(0.) + adv_loss = torch.tensor(0.) + + for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda x, y: abs(x - y).mean(), + scale_true, + scale_fake, + )) / len(scale_true) + + _dis, _adv = get_hinge_losses( + logits_true[i], + logits_fake[i], + ) + + dis_loss = dis_loss + _dis + adv_loss = adv_loss + _adv + + return dis_loss, adv_loss, feature_matching_distance + +# Discriminators from oobleck + +IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] + +TensorDict = tp.Dict[str, torch.Tensor] + +class SharedDiscriminatorConvNet(nn.Module): + + def __init__( + self, + in_size: int, + convolution: tp.Union[nn.Conv1d, nn.Conv2d], + out_size: int = 1, + capacity: int = 32, + n_layers: int = 4, + kernel_size: int = 15, + stride: int = 4, + activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), + normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, + ) -> None: + super().__init__() + channels = [in_size] + channels += list(capacity * 2**np.arange(n_layers)) + + if isinstance(stride, int): + stride = n_layers * [stride] + + net = [] + for i in range(n_layers): + if isinstance(kernel_size, int): + pad = kernel_size // 2 + s = stride[i] + else: + pad = kernel_size[0] // 2 + s = (stride[i], 1) + + net.append( + normalization( + convolution( + channels[i], + channels[i + 1], + kernel_size, + stride=s, + padding=pad, + ))) + net.append(activation()) + + net.append(convolution(channels[-1], out_size, 1)) + + self.net = nn.ModuleList(net) + + def forward(self, x) -> IndividualDiscriminatorOut: + features = [] + for layer in self.net: + x = layer(x) + if isinstance(layer, nn.modules.conv._ConvNd): + features.append(x) + score = x.reshape(x.shape[0], -1).mean(-1) + return score, features + + +class MultiScaleDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + n_scales: int, + **conv_kwargs) -> None: + super().__init__() + layers = [] + for _ in range(n_scales): + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer in self.layers: + s, f = layer(x) + score = score + s + features.extend(f) + x = nn.functional.avg_pool1d(x, 2) + return score, features + +class MultiPeriodDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + periods: tp.Sequence[int], + **conv_kwargs) -> None: + super().__init__() + layers = [] + self.periods = periods + + for _ in periods: + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer, n in zip(self.layers, self.periods): + s, f = layer(self.fold(x, n)) + score = score + s + features.extend(f) + return score, features + + def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: + pad = (n - (x.shape[-1] % n)) % n + x = nn.functional.pad(x, (0, pad)) + return x.reshape(*x.shape[:2], -1, n) + + +class MultiDiscriminator(nn.Module): + """ + Individual discriminators should take a single tensor as input (NxB C T) and + return a tuple composed of a score tensor (NxB) and a Sequence of Features + Sequence[NxB C' T']. + """ + + def __init__(self, discriminator_list: tp.Sequence[nn.Module], + keys: tp.Sequence[str]) -> None: + super().__init__() + self.discriminators = nn.ModuleList(discriminator_list) + self.keys = keys + + def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: + features = features.chunk(len(self.keys), 0) + return {k: features[i] for i, k in enumerate(self.keys)} + + @staticmethod + def concat_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = [] + if k in dict_a: + if isinstance(dict_a[k], list): + out_dict[k].extend(dict_a[k]) + else: + out_dict[k].append(dict_a[k]) + if k in dict_b: + if isinstance(dict_b[k], list): + out_dict[k].extend(dict_b[k]) + else: + out_dict[k].append(dict_b[k]) + return out_dict + + @staticmethod + def sum_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = 0. + if k in dict_a: + out_dict[k] = out_dict[k] + dict_a[k] + if k in dict_b: + out_dict[k] = out_dict[k] + dict_b[k] + return out_dict + + def forward(self, inputs: TensorDict) -> TensorDict: + discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) + all_scores = [] + all_features = [] + + for discriminator in self.discriminators: + score, features = discriminator(discriminator_input) + scores = self.unpack_tensor_to_dict(score) + scores = {f"score_{k}": scores[k] for k in scores.keys()} + all_scores.append(scores) + + features = map(self.unpack_tensor_to_dict, features) + features = reduce(self.concat_dicts, features) + features = {f"features_{k}": features[k] for k in features.keys()} + all_features.append(features) + + all_scores = reduce(self.sum_dicts, all_scores) + all_features = reduce(self.concat_dicts, all_features) + + inputs.update(all_scores) + inputs.update(all_features) + + return inputs + +class OobleckDiscriminator(nn.Module): + + def __init__( + self, + in_channels=1, + ): + super().__init__() + + multi_scale_discriminator = MultiScaleDiscriminator( + in_channels=in_channels, + n_scales=3, + ) + + multi_period_discriminator = MultiPeriodDiscriminator( + in_channels=in_channels, + periods=[2, 3, 5, 7, 11] + ) + + # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( + # filters=32, + # in_channels = in_channels, + # out_channels = 1, + # n_ffts = [2048, 1024, 512, 256, 128], + # hop_lengths = [512, 256, 128, 64, 32], + # win_lengths = [2048, 1024, 512, 256, 128] + # ) + + self.multi_discriminator = MultiDiscriminator( + [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], + ["reals", "fakes"] + ) + + def loss(self, reals, fakes): + inputs = { + "reals": reals, + "fakes": fakes, + } + + inputs = self.multi_discriminator(inputs) + + scores_real = inputs["score_reals"] + scores_fake = inputs["score_fakes"] + + features_real = inputs["features_reals"] + features_fake = inputs["features_fakes"] + + dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) + + feature_matching_distance = torch.tensor(0.) + + for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda real, fake: abs(real - fake).mean(), + scale_real, + scale_fake, + )) / len(scale_real) + + return dis_loss, gen_loss, feature_matching_distance + + +## Discriminators from Descript Audio Codec repo +## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt +class MPD(nn.Module): + def __init__(self, period, channels=1): + super().__init__() + + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): + super().__init__() + + self.convs = nn.ModuleList( + [ + WNConv1d(channels, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + channels: int = 1 + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + self.channels = channels + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class DACDiscriminator(nn.Module): + def __init__( + self, + channels: int = 1, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p, channels=channels) for p in periods] + discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, **discriminator_kwargs): + super().__init__() + self.discriminator = DACDiscriminator(**discriminator_kwargs) + + def forward(self, fake, real): + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + + def loss(self, fake, real): + gen_loss, feature_distance = self.generator_loss(fake, real) + dis_loss = self.discriminator_loss(fake, real) + + return dis_loss, gen_loss, feature_distance \ No newline at end of file diff --git a/stable_audio_tools/models/dit.py b/stable_audio_tools/models/dit.py new file mode 100644 index 0000000..42991c7 --- /dev/null +++ b/stable_audio_tools/models/dit.py @@ -0,0 +1,379 @@ +import typing as tp + +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from x_transformers import ContinuousTransformerWrapper, Encoder + +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + + if cond_token_dim > 0: + # Conditioning tokens + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "x-transformers": + self.transformer = ContinuousTransformerWrapper( + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + max_seq_len=0, #Not relevant without absolute positional embeds + attn_layers = Encoder( + dim=embed_dim, + depth=depth, + heads=num_heads, + attn_flash = True, + cross_attend = cond_token_dim > 0, + dim_context=None if cond_embed_dim == 0 else cond_embed_dim, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + elif self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + return_info=False, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) # MLP endecoder, shape: [1, 130, 768] + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": # True + if prepend_inputs is None: # True + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) # [1, 1, 1536] + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] # 1 + + x = self.preprocess_conv(x) + x # [1, 64, 1024] + + x = rearrange(x, "b c t -> b t c") # [1, 1024, 64] + + extra_args = {} + + if self.global_cond_type == "adaLN": # 'prepend' + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: # self.patch_size==1 + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) + elif self.transformer_type == "continuous_transformer": + + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + elif self.transformer_type == "mm_transformer": + output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # CFG dropout + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/stable_audio_tools/models/factory.py b/stable_audio_tools/models/factory.py new file mode 100644 index 0000000..4188703 --- /dev/null +++ b/stable_audio_tools/models/factory.py @@ -0,0 +1,153 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/stable_audio_tools/models/lm.py b/stable_audio_tools/models/lm.py new file mode 100644 index 0000000..f7e216f --- /dev/null +++ b/stable_audio_tools/models/lm.py @@ -0,0 +1,542 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone +from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p + +from .codebook_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLanguageModel(nn.Module): + def __init__( + self, + pattern_provider: CodebooksPatternProvider, + backbone: AudioLMBackbone, + num_quantizers: int, + codebook_size: int + ): + super().__init__() + + self.pattern_provider = pattern_provider + self.backbone = backbone + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + + self.masked_token_id = codebook_size + + # Per-quantizer embedders + # Add one for the mask embed + self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) + + # Per-quantizer output heads + self.quantizer_heads = nn.ModuleList([ + nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) + ]) + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + + batch, num_quantizers, seq_len = sequence.shape + + assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" + + backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] + + dtype = next(self.parameters()).dtype + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(dtype) + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.to(dtype) + + backbone_input = backbone_input.to(dtype) + + output = self.backbone( + backbone_input, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + # Run output through quantizer heads + logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] + + return logits + + def compute_logits( + self, + codes, #[batch, num_quantizers, seq_len] + **kwargs): + """ + Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning + Handles translation between input sequence and pattern-shifted sequence + Only used during training + """ + + batch, _, seq_len = codes.shape + + pattern = self.pattern_provider.get_pattern(seq_len) + + # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps + shifted_codes, _, _ = pattern.build_pattern_sequence( + codes, + self.masked_token_id, + keep_only_valid_steps=True + ) + + # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] + logits = self(shifted_codes, **kwargs) + + # Rearrange logits to prepare to revert pattern + logits = rearrange(logits, "b n s c -> b c n s") + + # Revert sequence logits back to original sequence length, removing masked steps + logits, _, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=True + ) + + logits = rearrange(logits, "b c n t -> b n t c") + + logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] + + return LMOutput(logits=logits, mask=logits_mask) + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] + ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + codes, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + codebook_pattern = lm_config.get("codebook_pattern", "delay") + + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'musiclm': MusicLMPattern, + } + + pretransform_config = model_config.get("pretransform", None) + + pretransform = create_pretransform_from_config(pretransform_config, sample_rate) + + assert pretransform.is_discrete, "Pretransform must be discrete" + + min_input_length = pretransform.downsampling_ratio + + pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/stable_audio_tools/models/local_attention.py b/stable_audio_tools/models/local_attention.py new file mode 100644 index 0000000..893ce11 --- /dev/null +++ b/stable_audio_tools/models/local_attention.py @@ -0,0 +1,278 @@ +import torch + +from einops import rearrange +from torch import nn + +from .blocks import AdaRMSNorm +from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py +class ContinuousLocalTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_in = None, + dim_out = None, + causal = False, + local_attn_window_size = 64, + heads = 8, + ff_mult = 2, + cond_dim = 0, + cross_attn_cond_dim = 0, + **kwargs + ): + super().__init__() + + dim_head = dim//heads + + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() + + self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() + + self.local_attn_window_size = local_attn_window_size + + self.cond_dim = cond_dim + + self.cross_attn_cond_dim = cross_attn_cond_dim + + self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) + + for _ in range(depth): + + self.layers.append(nn.ModuleList([ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + Attention( + dim=dim, + dim_heads=dim_head, + dim_context = cross_attn_cond_dim, + zero_init_output=True + ) if self.cross_attn_cond_dim > 0 else nn.Identity(), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim = dim, mult = ff_mult, no_bias=True) + ])) + + def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): + + x = checkpoint(self.project_in, x) + + if prepend_cond is not None: + x = torch.cat([prepend_cond, x], dim=1) + + pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + + for attn_norm, attn, xattn, ff_norm, ff in self.layers: + + residual = x + if cond is not None: + x = checkpoint(attn_norm, x, cond) + else: + x = checkpoint(attn_norm, x) + + x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual + + if cross_attn_cond is not None: + x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x + + residual = x + + if cond is not None: + x = checkpoint(ff_norm, x, cond) + else: + x = checkpoint(ff_norm, x) + + x = checkpoint(ff, x) + residual + + return checkpoint(self.project_out, x) + +class TransformerDownsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim = 768, + depth = 3, + heads = 12, + downsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.downsample_ratio = downsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size=local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) + + + def forward(self, x): + + x = checkpoint(self.project_in, x) + + # Compute + x = self.transformer(x) + + # Trade sequence length for channels + x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) + + # Project back to embed dim + x = checkpoint(self.project_down, x) + + return x + +class TransformerUpsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim, + depth = 3, + heads = 12, + upsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.upsample_ratio = upsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size = local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) + + def forward(self, x): + + # Project to embed dim + x = checkpoint(self.project_in, x) + + # Project to increase channel dim + x = checkpoint(self.project_up, x) + + # Trade channels for sequence length + x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) + + # Compute + x = self.transformer(x) + + return x + + +class TransformerEncoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [96, 192, 384, 768], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerDownsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + downsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + + return x + + +class TransformerDecoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [768, 384, 192, 96], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerUpsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + upsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + return x \ No newline at end of file diff --git a/stable_audio_tools/models/pqmf.py b/stable_audio_tools/models/pqmf.py new file mode 100644 index 0000000..007fdb5 --- /dev/null +++ b/stable_audio_tools/models/pqmf.py @@ -0,0 +1,393 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from scipy.optimize import fmin +from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + +class PQMF(nn.Module): + """ + Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. + Uses polyphase representation which is computationally more efficient for real-time. + + Parameters: + - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. + - num_bands (int): Number of desired frequency bands. It must be a power of 2. + """ + + def __init__(self, attenuation, num_bands): + super(PQMF, self).__init__() + + # Ensure num_bands is a power of 2 + is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + assert is_power_of_2, "'num_bands' must be a power of 2." + + # Create the prototype filter + prototype_filter = design_prototype_filter(attenuation, num_bands) + filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) + padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) + + # Register filters and settings + self.register_buffer("filter_bank", padded_filter_bank) + self.register_buffer("prototype", prototype_filter) + self.num_bands = num_bands + + def forward(self, signal): + """Decompose the signal into multiple frequency bands.""" + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + signal = prepare_signal_dimensions(signal) + # The signal length must be a multiple of num_bands. Pad it with zeros. + signal = pad_signal(signal, self.num_bands) + # run it + signal = polyphase_analysis(signal, self.filter_bank) + return apply_alias_cancellation(signal) + + def inverse(self, bands): + """Reconstruct the original signal from the frequency bands.""" + bands = apply_alias_cancellation(bands) + return polyphase_synthesis(bands, self.filter_bank) + + +def prepare_signal_dimensions(signal): + """ + Rearrange signal into Batch x Channels x Length. + + Parameters + ---------- + signal : torch.Tensor or numpy.ndarray + The input signal. + + Returns + ------- + torch.Tensor + Preprocessed signal tensor. + """ + # Convert numpy to torch tensor + if isinstance(signal, np.ndarray): + signal = torch.from_numpy(signal) + + # Ensure tensor + if not isinstance(signal, torch.Tensor): + raise ValueError("Input should be either a numpy array or a PyTorch tensor.") + + # Modify dimension of signal to Batch x Channels x Length + if signal.dim() == 1: + # This is just a mono signal. Unsqueeze to 1 x 1 x Length + signal = signal.unsqueeze(0).unsqueeze(0) + elif signal.dim() == 2: + # This is a multi-channel signal (e.g. stereo) + # Rearrange so that larger dimension (Length) is last + if signal.shape[0] > signal.shape[1]: + signal = signal.T + # Unsqueeze to 1 x Channels x Length + signal = signal.unsqueeze(0) + return signal + +def pad_signal(signal, num_bands): + """ + Pads the signal to make its length divisible by the given number of bands. + + Parameters + ---------- + signal : torch.Tensor + The input signal tensor, where the last dimension represents the signal length. + + num_bands : int + The number of bands by which the signal length should be divisible. + + Returns + ------- + torch.Tensor + The padded signal tensor. If the original signal length was already divisible + by num_bands, returns the original signal unchanged. + """ + remainder = signal.shape[-1] % num_bands + if remainder > 0: + padding_size = num_bands - remainder + signal = nn.functional.pad(signal, (0, padding_size)) + return signal + +def generate_modulated_filter_bank(prototype_filter, num_bands): + """ + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + + Parameters + ---------- + prototype_filter : torch.Tensor + The prototype filter used as the basis for modulation. + num_bands : int + The number of desired subbands or filters. + + Returns + ------- + torch.Tensor + A bank of cosine modulated filters. + """ + + # Initialize indices for modulation. + subband_indices = torch.arange(num_bands).reshape(-1, 1) + + # Calculate the length of the prototype filter. + filter_length = prototype_filter.shape[-1] + + # Generate symmetric time indices centered around zero. + time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) + + # Calculate phase offsets to ensure orthogonality between subbands. + phase_offsets = (-1)**subband_indices * np.pi / 4 + + # Compute the cosine modulation function. + modulation = torch.cos( + (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets + ) + + # Apply modulation to the prototype filter. + modulated_filters = 2 * prototype_filter * modulation + + return modulated_filters + + +def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): + """ + Design a lowpass filter using the Kaiser window. + + Parameters + ---------- + angular_cutoff : float + The angular frequency cutoff of the filter. + attenuation : float + The desired stopband attenuation in decibels (dB). + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The designed lowpass filter coefficients. + """ + + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) + + # Ensure the estimated length is odd. + estimated_length = 2 * (estimated_length // 2) + 1 + + if filter_length is None: + filter_length = estimated_length + + return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + +def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): + """ + Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 + + Parameters + ---------- + angular_cutoff : float + Angular frequency cutoff of the filter. + attenuation : float + Desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. + + Returns + ------- + float + The computed objective (loss) value for the given filter specs. + """ + + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) + convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + +def design_prototype_filter(attenuation, num_bands, filter_length=None): + """ + Design the optimal prototype filter for a multiband system given the desired specs. + + Parameters + ---------- + attenuation : float + The desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The optimal prototype filter coefficients. + """ + + optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, disp=0)[0] + + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) + return torch.tensor(prototype_filter, dtype=torch.float32) + +def pad_to_nearest_power_of_two(x): + """ + Pads the input tensor 'x' on both sides such that its last dimension + becomes the nearest larger power of two. + + Parameters: + ----------- + x : torch.Tensor + The input tensor to be padded. + + Returns: + -------- + torch.Tensor + The padded tensor. + """ + current_length = x.shape[-1] + target_length = 2**math.ceil(math.log2(current_length)) + + total_padding = target_length - current_length + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + + return nn.functional.pad(x, (left_padding, right_padding)) + +def apply_alias_cancellation(x): + """ + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second + row's first element in a tensor. + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during + the reconstruction. + + Parameters: + ----------- + x : torch.Tensor + The input tensor. + + Returns: + -------- + torch.Tensor + Tensor with specific elements' sign inverted for alias cancellation. + """ + + # Create a mask of the same shape as 'x', initialized with all ones + mask = torch.ones_like(x) + + # Update specific elements in the mask to -1 to perform inversion + mask[..., 1::2, ::2] = -1 + + # Apply the mask to the input tensor 'x' + return x * mask + +def ensure_odd_length(tensor): + """ + Pads the last dimension of a tensor to ensure its size is odd. + + Parameters: + ----------- + tensor : torch.Tensor + Input tensor whose last dimension might need padding. + + Returns: + -------- + torch.Tensor + The original tensor if its last dimension was already odd, + or the padded tensor with an odd-sized last dimension. + """ + + last_dim_size = tensor.shape[-1] + + if last_dim_size % 2 == 0: + tensor = nn.functional.pad(tensor, (0, 1)) + + return tensor + +def polyphase_analysis(signal, filter_bank): + """ + Applies the polyphase method to efficiently analyze the signal using a filter bank. + + Parameters: + ----------- + signal : torch.Tensor + Input signal tensor with shape (Batch x Channels x Length). + + filter_bank : torch.Tensor + Filter bank tensor with shape (Bands x Length). + + Returns: + -------- + torch.Tensor + Signal split into sub-bands. (Batch x Channels x Bands x Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) + + # Rearrange the filter bank for matching signal shape + filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) + + # Apply convolution with appropriate padding to maintain spatial dimensions + padding = filter_bank.shape[-1] // 2 + filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) + + # Truncate the last dimension post-convolution to adjust the output shape + filtered_signal = filtered_signal[..., :-1] + # Rearrange the first dimension back into Batch x Channels + filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) + + return filtered_signal + +def polyphase_synthesis(signal, filter_bank): + """ + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + + Parameters + ---------- + signal : torch.Tensor + Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). + + filter_bank : torch.Tensor + Analysis filter bank (shape: Bands x Length). + + should_rearrange : bool, optional + Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. + + Returns + ------- + torch.Tensor + Reconstructed signal (shape: Batch x Channels X Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange the filter bank + filter_bank = filter_bank.flip(-1) + filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) + + # Combine Batch x Channels into one dimension for now. + signal = rearrange(signal, "b c n t -> (b c) n t") + + # Apply convolution with appropriate padding + padding_amount = filter_bank.shape[-1] // 2 + 1 + reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) + + # Scale the result + reconstructed_signal = reconstructed_signal[..., :-1] * num_bands + + # Reorganize the output and truncate + reconstructed_signal = reconstructed_signal.flip(1) + reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] + + return reconstructed_signal \ No newline at end of file diff --git a/stable_audio_tools/models/pretrained.py b/stable_audio_tools/models/pretrained.py new file mode 100644 index 0000000..67b34fb --- /dev/null +++ b/stable_audio_tools/models/pretrained.py @@ -0,0 +1,25 @@ +import json + +from .factory import create_model_from_config +from .utils import load_ckpt_state_dict + +from huggingface_hub import hf_hub_download + +def get_pretrained_model(name: str): + + model_config_path = hf_hub_download(name, filename="config.json", repo_type='model') + + with open(model_config_path) as f: + model_config = json.load(f) + + model = create_model_from_config(model_config) + + # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file + try: + model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') + except Exception as e: + model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') + + model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + return model, model_config \ No newline at end of file diff --git a/stable_audio_tools/models/pretransforms.py b/stable_audio_tools/models/pretransforms.py new file mode 100644 index 0000000..c9942db --- /dev/null +++ b/stable_audio_tools/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/stable_audio_tools/models/temptransformer.py b/stable_audio_tools/models/temptransformer.py new file mode 100644 index 0000000..40cf3d2 --- /dev/null +++ b/stable_audio_tools/models/temptransformer.py @@ -0,0 +1,190 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class SA_PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class SA_FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class SA_Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = dots.softmax(dim=-1) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +class ReAttention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.reattn_weights = nn.Parameter(torch.randn(heads, heads)) + + self.reattn_norm = nn.Sequential( + Rearrange('b h i j -> b i j h'), + nn.LayerNorm(heads), + Rearrange('b i j h -> b h i j') + ) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + # attention + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + attn = dots.softmax(dim=-1) + + # re-attention + + attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) + attn = self.reattn_norm(attn) + + # aggregate and out + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class LeFF(nn.Module): + + def __init__(self, dim = 192, scale = 4, depth_kernel = 3): + super().__init__() + + scale_dim = dim*scale + self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim), + Rearrange('b n c -> b c n'), + nn.BatchNorm1d(scale_dim), + nn.GELU(), + Rearrange('b c (h w) -> b c h w', h=14, w=14) + ) + + self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False), + nn.BatchNorm2d(scale_dim), + nn.GELU(), + Rearrange('b c h w -> b (h w) c', h=14, w=14) + ) + + self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim), + Rearrange('b n c -> b c n'), + nn.BatchNorm1d(dim), + nn.GELU(), + Rearrange('b c n -> b n c') + ) + + def forward(self, x): + x = self.up_proj(x) + x = self.depth_conv(x) + x = self.down_proj(x) + return x + + +class LCAttention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = dots.softmax(dim=-1) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + +class SA_Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + self.norm = nn.LayerNorm(dim) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + SA_PreNorm(dim, SA_Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + SA_PreNorm(dim, SA_FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) \ No newline at end of file diff --git a/stable_audio_tools/models/transformer.py b/stable_audio_tools/models/transformer.py new file mode 100644 index 0000000..d5b037e --- /dev/null +++ b/stable_audio_tools/models/transformer.py @@ -0,0 +1,812 @@ +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from typing import Callable, Literal +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'none'] = 'none', + natten_kernel_size = None + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + + # Using 1d neighborhood attention + self.natten_kernel_size = natten_kernel_size + if natten_kernel_size is not None: + return + + self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + def flash_attn( + self, + q, + k, + v, + mask = None, + causal = None + ): + batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device + kv_heads = k.shape[1] + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if heads != kv_heads: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = heads // kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + + causal = self.causal if causal is None else causal + + if q_len == 1 and causal: + causal = False + + if mask is not None: + assert mask.ndim == 4 + mask = mask.expand(batch, heads, q_len, k_len) + + # handle kv cache - this should be bypassable in updated flash attention 2 + + if k_len > q_len and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + if mask is None: + mask = ~causal_mask + else: + mask = mask & ~causal_mask + causal = False + + # manually handle causal mask, if another mask was given + + row_is_entirely_masked = None + + if mask is not None and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + mask = mask & ~causal_mask + + # protect against an entire row being masked out + + row_is_entirely_masked = ~mask.any(dim = -1) + mask[..., 0] = mask[..., 0] | row_is_entirely_masked + + causal = False + + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + is_causal = causal + ) + + # for a row that is entirely masked out, should zero out the output of that row token + + if row_is_entirely_masked is not None: + out = out.masked_fill(row_is_entirely_masked[..., None], 0.) + + return out + + def forward( + self, + x, + context = None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + causal = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) # [B, 24, 1025, 64] + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm == "ln": + q = self.q_norm(q) + k = self.k_norm(k) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.natten_kernel_size is not None: + if natten is None: + raise ImportError('natten not installed, please install natten to use neighborhood attention') + + dtype_in = q.dtype + q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) + + attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1) + + if final_attn_mask is not None: + attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) + + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + + out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) + + # Prioritize Flash Attention 2 + elif self.use_fa_flash: + assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + # Flash Attention 2 requires FP16 inputs + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + + # Fall back to PyTorch implementation + elif self.use_pt_flash: + out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask) + + else: + # Fall back to custom implementation + + if h != kv_h: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = h // kv_h + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + scale = 1. / (q.shape[-1] ** 0.5) + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + + i, j, dtype = *dots.shape[-2:], dots.dtype + + mask_value = -torch.finfo(dots.dtype).max + + if final_attn_mask is not None: + dots = dots.masked_fill(~final_attn_mask, mask_value) + + if causal: + causal_mask = self.create_causal_mask(i, j, device = device) + dots = dots.masked_fill(causal_mask, mask_value) + + attn = F.softmax(dots, dim=-1, dtype=torch.float32) + attn = attn.type(dtype) + + out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) + + return out + + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads = dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads = dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + nn.init.zeros_(self.to_scale_shift_gate[1].weight) + + def forward( + self, + x, + context = None, + global_cond=None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + adapter=None + ): + + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: # False + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + x = x + residual + + if context is not None: + + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + + if context is not None: + + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + global_cond = None, + return_info = False, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + + mask = torch.cat((prepend_mask, mask), dim = -1) + + # Attention layers + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + for index, layer in enumerate(self.layers): + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/stable_audio_tools/models/utils.py b/stable_audio_tools/models/utils.py new file mode 100644 index 0000000..a3d92cf --- /dev/null +++ b/stable_audio_tools/models/utils.py @@ -0,0 +1,92 @@ +import torch +from safetensors.torch import load_file + +from torch.nn.utils import remove_weight_norm +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + + +def load_ckpt_state_dict(ckpt_path): + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + return state_dict + +def remove_weight_norm_from_model(model): + for module in model.modules(): + if hasattr(module, "weight"): + print(f"Removing weight norm from {module}") + remove_weight_norm(module) + + return model + +# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + + if num_samples == 1: + q = torch.empty_like(input).exponential_(1, generator=generator) + return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) + + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def next_power_of_two(n): + return 2 ** (n - 1).bit_length() + +def next_multiple_of_64(n): + return ((n + 63) // 64) * 64 \ No newline at end of file diff --git a/stable_audio_tools/models/wavelets.py b/stable_audio_tools/models/wavelets.py new file mode 100644 index 0000000..a359e39 --- /dev/null +++ b/stable_audio_tools/models/wavelets.py @@ -0,0 +1,82 @@ +"""The 1D discrete wavelet transform for PyTorch.""" + +from einops import rearrange +import pywt +import torch +from torch import nn +from torch.nn import functional as F +from typing import Literal + + +def get_filter_bank(wavelet): + filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) + if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): + filt = filt[:, 1:] + return filt + +class WaveletEncode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[:2, None] + kernel = torch.flip(kernel, dims=(-1,)) + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels], x[:, self.channels :] + pad = self.kernel.shape[-1] // 2 + low = F.pad(low, (pad, pad), "reflect") + low = F.conv1d(low, self.kernel, stride=2) + rest = rearrange( + rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x + + +class WaveletDecode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[2:, None] + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] + pad = self.kernel.shape[-1] // 2 + 2 + low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) + low = F.pad(low, (pad, pad), "reflect") + low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) + low = F.conv_transpose1d( + low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 + ) + low = low[..., pad - 1 : -pad] + rest = rearrange( + rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x \ No newline at end of file diff --git a/stable_audio_tools/training/__init__.py b/stable_audio_tools/training/__init__.py new file mode 100644 index 0000000..f77486b --- /dev/null +++ b/stable_audio_tools/training/__init__.py @@ -0,0 +1 @@ +from .factory import create_training_wrapper_from_config, create_demo_callback_from_config diff --git a/stable_audio_tools/training/autoencoders.py b/stable_audio_tools/training/autoencoders.py new file mode 100644 index 0000000..91bee39 --- /dev/null +++ b/stable_audio_tools/training/autoencoders.py @@ -0,0 +1,476 @@ +import torch +import torchaudio +import wandb +from einops import rearrange +from safetensors.torch import save_file, save_model +from ema_pytorch import EMA +from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss +import pytorch_lightning as pl +from ..models.autoencoders import AudioAutoencoder +from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss +from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck +from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss +from .utils import create_optimizer_from_config, create_scheduler_from_config + + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image + +class AutoencoderTrainingWrapper(pl.LightningModule): + def __init__( + self, + autoencoder: AudioAutoencoder, + lr: float = 1e-4, + warmup_steps: int = 0, + encoder_freeze_on_warmup: bool = False, + sample_rate=48000, + loss_config: dict = None, + optimizer_configs: dict = None, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: AudioAutoencoder = None + ): + super().__init__() + + self.automatic_optimization = False + + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + + self.force_input_mono = force_input_mono + + self.teacher_model = teacher_model + + if optimizer_configs is None: + optimizer_configs ={ + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + } + + } + + self.optimizer_configs = optimizer_configs + + if loss_config is None: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + loss_config = { + "discriminator": { + "type": "encodec", + "config": { + "n_ffts": scales, + "hop_lengths": hop_sizes, + "win_lengths": win_lengths, + "filters": 32 + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0, + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + }, + "weights": { + "mrstft": 1.0, + } + }, + "time": { + "type": "l1", + "config": {}, + "weights": { + "l1": 0.0, + } + } + } + + self.loss_config = loss_config + + # Spectral reconstruction loss + + stft_loss_args = loss_config['spectral']['config'] + + if self.autoencoder.out_channels == 2: + self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Discriminator + + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + + self.gen_loss_modules = [] + + # Adversarial and feature matching losses + self.gen_loss_modules += [ + ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), + ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), + ] + + if self.teacher_model is not None: + # Distillation losses + + stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss + AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder + ] + + else: + + # Reconstruction loss + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.autoencoder.out_channels == 2: + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), + AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), + ] + + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.loss_config['time']['weights']['l1'] > 0.0: + self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) + + if self.autoencoder.bottleneck is not None: + self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) + + self.losses_gen = MultiLoss(self.gen_loss_modules) + + self.disc_loss_modules = [ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ] + + self.losses_disc = MultiLoss(self.disc_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + + self.use_ema = use_ema + + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + def configure_optimizers(self): + + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) + opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) + + if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + + return [opt_gen, opt_disc] + + def training_step(self, batch, batch_idx): + reals, _ = batch + + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + + decoded = self.autoencoder.decode(latents) + + loss_info["decoded"] = decoded + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) + else: + loss_dis = torch.tensor(0.).to(reals) + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen, opt_disc = self.optimizers() + + lr_schedulers = self.lr_schedulers() + + sched_gen = None + sched_disc = None + + if lr_schedulers is not None: + sched_gen, sched_disc = lr_schedulers + + # Train the discriminator + if self.global_step % 2 and self.warmed_up: + loss, losses = self.losses_disc(loss_info) + + log_dict = { + 'train/disc_lr': opt_disc.param_groups[0]['lr'] + } + + opt_disc.zero_grad() + self.manual_backward(loss) + opt_disc.step() + + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + + loss, losses = self.losses_gen(loss_info) + + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + opt_gen.step() + + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict = { + 'train/loss': loss.detach(), + 'train/latent_std': latents.std().detach(), + 'train/data_std': data_std.detach(), + 'train/gen_lr': opt_gen.param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + module.eval() + + try: + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + + fakes = module.autoencoder.decode(latents) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses \ No newline at end of file diff --git a/stable_audio_tools/training/diffusion.py b/stable_audio_tools/training/diffusion.py new file mode 100644 index 0000000..343ab46 --- /dev/null +++ b/stable_audio_tools/training/diffusion.py @@ -0,0 +1,1656 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +import auraloss +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper +from ..models.autoencoders import DiffusionAutoencoder +from ..models.diffusion_prior import PriorType +from .autoencoders import create_loss_modules_from_bottleneck +from .losses import AuralossLoss, MSELoss, MultiLoss +from .utils import create_optimizer_from_config, create_scheduler_from_config + +from time import time + + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionUncondTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). + ''' + def __init__( + self, + model: DiffusionModelWrapper, + lr: float = 1e-4, + pre_encoded: bool = False + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + diffusion_input = reals + + loss_info = {} + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + with torch.cuda.amp.autocast(): + v = self.diffusion(noised_inputs, t) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionUncondDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_samples = module.diffusion.sample_size + + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + with torch.cuda.amp.autocast(): + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionCondTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = None, + mask_padding: bool = False, + mask_padding_dropout: float = 0.0, + use_ema: bool = True, + log_loss_info: bool = True, + optimizer_configs: dict = None, + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + ): + super().__init__() + + self.diffusion = model + + if use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.mask_padding = mask_padding + self.mask_padding_dropout = mask_padding_dropout + + self.cfg_dropout_prob = cfg_dropout_prob + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + + if 'av_loss' in optimizer_configs and optimizer_configs['av_loss']['if_add_av_loss']: + av_align_weight = optimizer_configs['av_loss']['config']['weight'] + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0 - av_align_weight, + mask_key="padding_mask" if self.mask_padding else None, + name="mse_loss" + ) + ] + else: + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + mask_key="padding_mask" if self.mask_padding else None, + name="mse_loss" + ) + ] + + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + + + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + + use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout + + # Create batch tensor of attention masks from the "mask" field of the metadata array + if use_padding_mask: + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + if use_padding_mask: + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) # [0.1360, 0.5232] + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + if use_padding_mask: + extra_args["mask"] = padding_masks + + with torch.cuda.amp.autocast(): + p.tick("amp") + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + }) + + loss, losses = self.losses(loss_info) + + p.tick("loss") + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def validation_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + + # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding + use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout + + # Create batch tensor of attention masks from the "mask" field of the metadata array + if use_padding_mask: + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + if use_padding_mask: + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + if use_padding_mask: + extra_args["mask"] = padding_masks + + with torch.cuda.amp.autocast(): + p.tick("amp") + + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + }) + + loss, losses = self.losses(loss_info) + + p.tick("loss") + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + # loss_all = F.binary_cross_entropy_with_logits(output, targets, reduction="none") + + + sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze() + # sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n") + # loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + + log_dict = { + 'valid/loss': loss.detach(), + 'valid/std_data': diffusion_input.std(), + 'valid/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + + for loss_name, loss_value in losses.items(): + log_dict[f"valid/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + # self.log('val_loss', val_loss, on_epoch=True, on_step=True) + + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + demo_steps=250, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + demo_cond_from_batch: bool = False, + display_audio_cond: bool = False + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + # If true, the callback will use the metadata from the batch to generate the demo conditioning + self.demo_cond_from_batch = demo_cond_from_batch + + # If true, the callback will display the audio conditioning + self.display_audio_cond = display_audio_cond + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_samples = self.demo_samples + + demo_cond = self.demo_conditioning + + if self.demo_cond_from_batch: + # Get metadata from the batch + demo_cond = batch[1][:self.num_demos] + + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + print("Getting conditioning") + with torch.cuda.amp.autocast(): + conditioning = module.diffusion.conditioner(demo_cond, module.device) + + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + log_dict = {} + + if self.display_audio_cond: + audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) + audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') + + filename = f'demo_audio_cond_{trainer.global_step:08}.wav' + audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, audio_inputs, self.sample_rate) + log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") + log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + + print(f"Generating demo for cfg scale {cfg_scale}") + + with torch.cuda.amp.autocast(): + model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() + +class DiffusionCondInpaintTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + max_mask_segments = 10, + log_loss_info: bool = False, + optimizer_configs: dict = None, + use_ema: bool = True, + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + ): + super().__init__() + + self.diffusion = model + + self.use_ema = use_ema + + if self.use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.cfg_dropout_prob = cfg_dropout_prob + + self.lr = lr + self.max_mask_segments = max_mask_segments + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def random_mask(self, sequence, max_mask_length): + b, _, sequence_length = sequence.size() + + # Create a mask tensor for each batch element + masks = [] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + # if use_padding_mask: + # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + with torch.cuda.amp.autocast(): + p.tick("amp") + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + }) + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondInpaintDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.demo_cfg_scales = demo_cfg_scales + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + try: + log_dict = {} + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + if not module.pre_encoded: + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.cuda.amp.autocast(): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = module.diffusion.conditioner(metadata, module.device) + + masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + print(f"Generating demo for cfg scale {cfg_scale}") + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + with torch.cuda.amp.autocast(): + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + +class DiffusionAutoencoderTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a diffusion autoencoder + ''' + def __init__( + self, + model: DiffusionAutoencoder, + lr: float = 1e-4, + ema_copy = None, + use_reconstruction_loss: bool = False + ): + super().__init__() + + self.diffae = model + + self.diffae_ema = EMA( + self.diffae, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + if model.bottleneck is not None: + # TODO: Use loss config for configurable bottleneck weights and reconstruction losses + loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.out_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + if out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffae.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.diffae.pretransform is not None: + with torch.no_grad(): + reals = self.diffae.pretransform.encode(reals) + + loss_info["reals"] = reals + + #Encode reals, skipping the pretransform since it was already applied + latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + if self.diffae.decoder is not None: + latents = self.diffae.decoder(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != reals.shape[2]: + latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') + + loss_info["latents_upsampled"] = latents + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.cuda.amp.autocast(): + v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffae.pretransform is not None: + pred = self.diffae.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std(), + 'train/latent_std': latents.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffae_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.diffae_ema.ema_model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionAutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad() and torch.cuda.amp.autocast(): + latents = module.diffae_ema.ema_model.encode(encoder_input).float() + fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + if module.diffae_ema.ema_model.pretransform is not None: + with torch.no_grad() and torch.cuda.amp.autocast(): + initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) + first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) + first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') + first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() + first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' + torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) + + log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) + + log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, + sample_rate=self.sample_rate, + caption=f'First Stage Reconstructed') + + log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) + + + trainer.logger.experiment.log(log_dict) + +def create_source_mixture(reals, num_sources=2): + # Create a fake mixture source by mixing elements from the training batch together with random offsets + source = torch.zeros_like(reals) + for i in range(reals.shape[0]): + sources_added = 0 + + js = list(range(reals.shape[0])) + random.shuffle(js) + for j in js: + if i == j or (i != j and sources_added < num_sources): + # Randomly offset the mixed element between 0 and the length of the source + seq_len = reals.shape[2] + offset = random.randint(0, seq_len-1) + source[i, :, offset:] += reals[j, :, :-offset] + if i == j: + # If this is the real one, shift the reals as well to ensure alignment + new_reals = torch.zeros_like(reals[i]) + new_reals[:, offset:] = reals[i, :, :-offset] + reals[i] = new_reals + sources_added += 1 + + return source + +class DiffusionPriorTrainingWrapper(pl.LightningModule): + ''' + Wrapper for training a diffusion prior for inverse problems + Prior types: + mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + ema_copy = None, + prior_type: PriorType = PriorType.MonoToStereo, + use_reconstruction_loss: bool = False, + log_loss_info: bool = False, + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.log_loss_info = log_loss_info + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.io_channels + + self.audio_out_channels = out_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + if self.audio_out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.loss_modules += [ + AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), + AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), + ] + + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + self.loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + self.prior_type = prior_type + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.prior_type == PriorType.MonoToStereo: + source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device) + loss_info["audio_reals_mono"] = source + else: + raise ValueError(f"Unknown prior type {self.prior_type}") + + if self.diffusion.pretransform is not None: + with torch.no_grad(): + reals = self.diffusion.pretransform.encode(reals) + + if self.prior_type in [PriorType.MonoToStereo]: + source = self.diffusion.pretransform.encode(source) + + if self.diffusion.conditioner is not None: + with torch.cuda.amp.autocast(): + conditioning = self.diffusion.conditioner(metadata, self.device) + else: + conditioning = {} + + loss_info["reals"] = reals + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.cuda.amp.autocast(): + + conditioning['source'] = [source] + + v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffusion.pretransform is not None: + pred = self.diffusion.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + if self.audio_out_channels == 2: + loss_info["pred_left"] = pred[:, 0:1, :] + loss_info["pred_right"] = pred[:, 1:2, :] + loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] + loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :] + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(v, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std() + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + #model = self.diffusion_ema.ema_model + model = self.diffusion + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionPriorDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + encoder_input = demo_reals + + if module.diffusion.conditioner is not None: + with torch.cuda.amp.autocast(): + conditioning_tensors = module.diffusion.conditioner(metadata, module.device) + + else: + conditioning_tensors = {} + + + with torch.no_grad() and torch.cuda.amp.autocast(): + if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: + source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) + + if module.diffusion.pretransform is not None: + encoder_input = module.diffusion.pretransform.encode(encoder_input) + source_input = module.diffusion.pretransform.encode(source) + else: + source_input = source + + conditioning_tensors['source'] = [source_input] + + fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + #Log the source + filename = f'source_{trainer.global_step:08}.wav' + source = rearrange(source, 'b d n -> d (b n)') + source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, source, self.sample_rate) + + log_dict[f'source'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Source') + + log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) + + trainer.logger.experiment.log(log_dict) \ No newline at end of file diff --git a/stable_audio_tools/training/factory.py b/stable_audio_tools/training/factory.py new file mode 100644 index 0000000..c3216d1 --- /dev/null +++ b/stable_audio_tools/training/factory.py @@ -0,0 +1,240 @@ +import torch +from torch.nn import Parameter +from ..models.factory import create_model_from_config + +def create_training_wrapper_from_config(model_config, model): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + ema_copy = create_model_from_config(model_config) + ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + + latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) + + teacher_model = training_config.get("teacher_model", None) + if teacher_model is not None: + teacher_model = create_model_from_config(teacher_model) + teacher_model = teacher_model.eval().requires_grad_(False) + + teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) + if teacher_model_ckpt is not None: + teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) + else: + raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") + + return AutoencoderTrainingWrapper( + model, + lr=training_config["learning_rate"], + warmup_steps=training_config.get("warmup_steps", 0), + encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), + sample_rate=model_config["sample_rate"], + loss_config=training_config.get("loss_configs", None), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=use_ema, + ema_copy=ema_copy if use_ema else None, + force_input_mono=training_config.get("force_input_mono", False), + latent_mask_ratio=latent_mask_ratio, + teacher_model=teacher_model + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondTrainingWrapper + return DiffusionUncondTrainingWrapper( + model, + lr=training_config["learning_rate"], + pre_encoded=training_config.get("pre_encoded", False), + ) + elif model_type == 'diffusion_cond': + from .diffusion import DiffusionCondTrainingWrapper + return DiffusionCondTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + mask_padding=training_config.get("mask_padding", False), + mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), + use_ema = training_config.get("use_ema", True), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler = training_config.get("timestep_sampler", "uniform") + ) + elif model_type == 'diffusion_prior': + from .diffusion import DiffusionPriorTrainingWrapper + from ..models.diffusion_prior import PriorType + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + prior_type = training_config.get("prior_type", "mono_stereo") + + if prior_type == "mono_stereo": + prior_type_enum = PriorType.MonoToStereo + else: + raise ValueError(f"Unknown prior type: {prior_type}") + + return DiffusionPriorTrainingWrapper( + model, + lr=training_config["learning_rate"], + ema_copy=ema_copy, + prior_type=prior_type_enum, + log_loss_info=training_config.get("log_loss_info", False), + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), + ) + elif model_type == 'diffusion_cond_inpaint': + from .diffusion import DiffusionCondInpaintTrainingWrapper + return DiffusionCondInpaintTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + max_mask_segments = training_config.get("max_mask_segments", 10), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=training_config.get("use_ema", True), + pre_encoded=training_config.get("pre_encoded", False), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler = training_config.get("timestep_sampler", "uniform") + ) + elif model_type == 'diffusion_autoencoder': + from .diffusion import DiffusionAutoencoderTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return DiffusionAutoencoderTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config["learning_rate"], + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) + ) + elif model_type == 'lm': + from .lm import AudioLanguageModelTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return AudioLanguageModelTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config.get("learning_rate", None), + use_ema=training_config.get("use_ema", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_demo_callback_from_config(model_config, **kwargs): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + demo_config = training_config.get("demo", {}) + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderDemoCallback + return AutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondDemoCallback + return DiffusionUncondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"] + ) + elif model_type == "diffusion_autoencoder": + from .diffusion import DiffusionAutoencoderDemoCallback + return DiffusionAutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_prior": + from .diffusion import DiffusionPriorDemoCallback + return DiffusionPriorDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_cond": + from .diffusion import DiffusionCondDemoCallback + + return DiffusionCondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + num_demos=demo_config["num_demos"], + demo_cfg_scales=demo_config["demo_cfg_scales"], + demo_conditioning=demo_config.get("demo_cond", {}), + demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), + display_audio_cond=demo_config.get("display_audio_cond", False), + ) + elif model_type == "diffusion_cond_inpaint": + from .diffusion import DiffusionCondInpaintDemoCallback + + return DiffusionCondInpaintDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + demo_cfg_scales=demo_config["demo_cfg_scales"], + **kwargs + ) + + elif model_type == "lm": + from .lm import AudioLanguageModelDemoCallback + + return AudioLanguageModelDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), + demo_conditioning=demo_config.get("demo_cond", None), + num_demos=demo_config.get("num_demos", 8), + **kwargs + ) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') \ No newline at end of file diff --git a/stable_audio_tools/training/lm.py b/stable_audio_tools/training/lm.py new file mode 100644 index 0000000..e1fa9f7 --- /dev/null +++ b/stable_audio_tools/training/lm.py @@ -0,0 +1,267 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..models.lm import AudioLanguageModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLanguageModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + pre_encoded=False + ): + super().__init__() + + self.model = model + + self.model.pretransform.requires_grad_(False) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if not self.pre_encoded: + codes = self.model.pretransform.tokenize(reals) + else: + codes = reals + + padding_masks = [] + for md in metadata: + if md["padding_mask"].ndim == 1: + padding_masks.append(md["padding_mask"]) + else: + padding_masks.append(md["padding_mask"][0]) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + # Interpolate padding masks to the same length as the codes + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + condition_tensors = self.model.conditioner(metadata, self.device) + + lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] + + logits_mask = logits_mask & padding_masks + + cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) + + loss = cross_entropy + + log_dict = { + 'train/loss': loss.detach(), + 'train/cross_entropy': cross_entropy.detach(), + 'train/perplexity': torch.exp(cross_entropy).detach(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for k, ce_q in enumerate(cross_entropy_per_codebook): + log_dict[f'cross_entropy_q{k + 1}'] = ce_q + log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio + + #demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + ##Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes / fakes.abs().max() + fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/stable_audio_tools/training/losses/__init__.py b/stable_audio_tools/training/losses/__init__.py new file mode 100644 index 0000000..37fdea0 --- /dev/null +++ b/stable_audio_tools/training/losses/__init__.py @@ -0,0 +1 @@ +from .losses import * \ No newline at end of file diff --git a/stable_audio_tools/training/losses/auraloss.py b/stable_audio_tools/training/losses/auraloss.py new file mode 100644 index 0000000..9ab5405 --- /dev/null +++ b/stable_audio_tools/training/losses/auraloss.py @@ -0,0 +1,607 @@ +# Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0 +# You can find the license at LICENSES/LICENSE_AURALOSS.txt + +import torch +import numpy as np +from typing import List, Any +import scipy.signal + +def apply_reduction(losses, reduction="none"): + """Apply reduction to collection of losses.""" + if reduction == "mean": + losses = losses.mean() + elif reduction == "sum": + losses = losses.sum() + return losses + +def get_window(win_type: str, win_length: int): + """Return a window function. + + Args: + win_type (str): Window type. Can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + win_length (int): Window length + + Returns: + win: The window as a 1D torch tensor + """ + + try: + win = getattr(torch, win_type)(win_length) + except: + win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length)) + + return win + +class SumAndDifference(torch.nn.Module): + """Sum and difference signal extraction module.""" + + def __init__(self): + """Initialize sum and difference extraction module.""" + super(SumAndDifference, self).__init__() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, #channels, #samples). + Returns: + Tensor: Sum signal. + Tensor: Difference signal. + """ + if not (x.size(1) == 2): # inputs must be stereo + raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).") + + sum_sig = self.sum(x).unsqueeze(1) + diff_sig = self.diff(x).unsqueeze(1) + + return sum_sig, diff_sig + + @staticmethod + def sum(x): + return x[:, 0, :] + x[:, 1, :] + + @staticmethod + def diff(x): + return x[:, 0, :] - x[:, 1, :] + + +class FIRFilter(torch.nn.Module): + """FIR pre-emphasis filtering module. + + Args: + filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp" + coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85 + ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101 + plot (bool): Plot the magnitude respond of the filter. Default: False + + Based upon the perceptual loss pre-empahsis filters proposed by + [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). + + A-weighting filter - "aw" + First-order highpass - "hp" + Folded differentiator - "fd" + + Note that the default coefficeint value of 0.85 is optimized for + a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates. + """ + + def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): + """Initilize FIR pre-emphasis filtering module.""" + super(FIRFilter, self).__init__() + self.filter_type = filter_type + self.coef = coef + self.fs = fs + self.ntaps = ntaps + self.plot = plot + + import scipy.signal + + if ntaps % 2 == 0: + raise ValueError(f"ntaps must be odd (ntaps={ntaps}).") + + if filter_type == "hp": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) + elif filter_type == "fd": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1) + elif filter_type == "aw": + # Definition of analog A-weighting filter according to IEC/CD 1672. + f1 = 20.598997 + f2 = 107.65265 + f3 = 737.86223 + f4 = 12194.217 + A1000 = 1.9997 + + NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0] + DENs = np.polymul( + [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2], + [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2], + ) + DENs = np.polymul( + np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2] + ) + + # convert analog filter to digital filter + b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs) + + # compute the digital filter frequency response + w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) + + # then we fit to 101 tap FIR filter with least squares + taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) + + # now implement this digital FIR filter as a Conv1d layer + self.fir = torch.nn.Conv1d( + 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 + ) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) + + if plot: + from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) + + def forward(self, input, target): + """Calculate forward propagation. + Args: + input (Tensor): Predicted signal (B, #channels, #samples). + target (Tensor): Groundtruth signal (B, #channels, #samples). + Returns: + Tensor: Filtered signal. + """ + input = torch.nn.functional.conv1d( + input, self.fir.weight.data, padding=self.ntaps // 2 + ) + target = torch.nn.functional.conv1d( + target, self.fir.weight.data, padding=self.ntaps // 2 + ) + return input, target + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719). + """ + + def __init__(self): + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean() + +class STFTMagnitudeLoss(torch.nn.Module): + """STFT magnitude loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719) + and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1) + + Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the + compression strength (larger value results in more compression), and `log_eps` can be used + to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive + output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression. + + Args: + log (bool, optional): Log-scale the STFT magnitudes, + or use linear scale. Default: True + log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm. + Default: 0.0 + log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm. + Default: 1.0 + distance (str, optional): Distance function ["L1", "L2"]. Default: "L1" + reduction (str, optional): Reduction of the loss elements. Default: "mean" + """ + + def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"): + super(STFTMagnitudeLoss, self).__init__() + + self.log = log + self.log_eps = log_eps + self.log_fac = log_fac + + if distance == "L1": + self.distance = torch.nn.L1Loss(reduction=reduction) + elif distance == "L2": + self.distance = torch.nn.MSELoss(reduction=reduction) + else: + raise ValueError(f"Invalid distance: '{distance}'.") + + def forward(self, x_mag, y_mag): + if self.log: + x_mag = torch.log(self.log_fac * x_mag + self.log_eps) + y_mag = torch.log(self.log_fac * y_mag + self.log_eps) + return self.distance(x_mag, y_mag) + + +class STFTLoss(torch.nn.Module): + """STFT loss module. + + See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472). + + Args: + fft_size (int, optional): FFT size in samples. Default: 1024 + hop_size (int, optional): Hop size of the FFT in samples. Default: 256 + win_length (int, optional): Length of the FFT analysis window. Default: 1024 + window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of scaling frequency bins. Default: None. + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False + eps (float, optional): Small epsilon value for stablity. Default: 1e-8 + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + reduction (str, optional): Specifies the reduction to apply to the output: + 'none': no reduction will be applied, + 'mean': the sum of the output will be divided by the number of elements in the output, + 'sum': the output will be summed. + Default: 'mean' + mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. + device (str, optional): Place the filterbanks on specified device. Default: None + + Returns: + loss: + Aggreate loss term. Only returned if output='loss'. By default. + loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss: + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + + def __init__( + self, + fft_size: int = 1024, + hop_size: int = 256, + win_length: int = 1024, + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + eps: float = 1e-8, + output: str = "loss", + reduction: str = "mean", + mag_distance: str = "L1", + device: Any = None, + **kwargs + ): + super().__init__() + self.fft_size = fft_size + self.hop_size = hop_size + self.win_length = win_length + self.window = get_window(window, win_length) + self.w_sc = w_sc + self.w_log_mag = w_log_mag + self.w_lin_mag = w_lin_mag + self.w_phs = w_phs + self.sample_rate = sample_rate + self.scale = scale + self.n_bins = n_bins + self.perceptual_weighting = perceptual_weighting + self.scale_invariance = scale_invariance + self.eps = eps + self.output = output + self.reduction = reduction + self.mag_distance = mag_distance + self.device = device + + self.phs_used = bool(self.w_phs) + + self.spectralconv = SpectralConvergenceLoss() + self.logstft = STFTMagnitudeLoss( + log=True, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + self.linstft = STFTMagnitudeLoss( + log=False, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + + # setup mel filterbank + if scale is not None: + try: + import librosa.filters + except Exception as e: + print(e) + print("Try `pip install auraloss[all]`.") + + if self.scale == "mel": + assert sample_rate != None # Must set sample rate to use mel scale + assert n_bins <= fft_size # Must be more FFT bins than Mel bins + fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) + fb = torch.tensor(fb).unsqueeze(0) + + elif self.scale == "chroma": + assert sample_rate != None # Must set sample rate to use chroma scale + assert n_bins <= fft_size # Must be more FFT bins than chroma bins + fb = librosa.filters.chroma( + sr=sample_rate, n_fft=fft_size, n_chroma=n_bins + ) + + else: + raise ValueError( + f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'." + ) + + self.register_buffer("fb", fb) + + if scale is not None and device is not None: + self.fb = self.fb.to(self.device) # move filterbank to device + + if self.perceptual_weighting: + if sample_rate is None: + raise ValueError( + f"`sample_rate` must be supplied when `perceptual_weighting = True`." + ) + self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) + + def stft(self, x): + """Perform STFT. + Args: + x (Tensor): Input signal tensor (B, T). + + Returns: + Tensor: x_mag, x_phs + Magnitude and phase spectra (B, fft_size // 2 + 1, frames). + """ + x_stft = torch.stft( + x, + self.fft_size, + self.hop_size, + self.win_length, + self.window, + return_complex=True, + ) + x_mag = torch.sqrt( + torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) + ) + + # torch.angle is expensive, so it is only evaluated if the values are used in the loss + if self.phs_used: + x_phs = torch.angle(x_stft) + else: + x_phs = None + + return x_mag, x_phs + + def forward(self, input: torch.Tensor, target: torch.Tensor): + bs, chs, seq_len = input.size() + + if self.perceptual_weighting: # apply optional A-weighting via FIR filter + # since FIRFilter only support mono audio we will move channels to batch dim + input = input.view(bs * chs, 1, -1) + target = target.view(bs * chs, 1, -1) + + # now apply the filter to both + self.prefilter.to(input.device) + input, target = self.prefilter(input, target) + + # now move the channels back + input = input.view(bs, chs, -1) + target = target.view(bs, chs, -1) + + # compute the magnitude and phase spectra of input and target + self.window = self.window.to(input.device) + + x_mag, x_phs = self.stft(input.view(-1, input.size(-1))) + y_mag, y_phs = self.stft(target.view(-1, target.size(-1))) + + # apply relevant transforms + if self.scale is not None: + self.fb = self.fb.to(input.device) + x_mag = torch.matmul(self.fb, x_mag) + y_mag = torch.matmul(self.fb, y_mag) + + # normalize scales + if self.scale_invariance: + alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1])) + y_mag = y_mag * alpha.unsqueeze(-1) + + # compute loss terms + sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0 + log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0 + lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0 + phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0 + + # combine loss terms + loss = ( + (self.w_sc * sc_mag_loss) + + (self.w_log_mag * log_mag_loss) + + (self.w_lin_mag * lin_mag_loss) + + (self.w_phs * phs_loss) + ) + + loss = apply_reduction(loss, reduction=self.reduction) + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module. + + See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480) + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str, optional): Window to apply before FFT, options include: + 'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None. + scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False + """ + + def __init__( + self, + fft_sizes: List[int] = [1024, 2048, 512], + hop_sizes: List[int] = [120, 240, 50], + win_lengths: List[int] = [600, 1200, 240], + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + **kwargs, + ): + super().__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all + self.fft_sizes = fft_sizes + self.hop_sizes = hop_sizes + self.win_lengths = win_lengths + + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [ + STFTLoss( + fs, + ss, + wl, + window, + w_sc, + w_log_mag, + w_lin_mag, + w_phs, + sample_rate, + scale, + n_bins, + perceptual_weighting, + scale_invariance, + **kwargs, + ) + ] + + def forward(self, x, y): + mrstft_loss = 0.0 + sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], [] + + for f in self.stft_losses: + if f.output == "full": # extract just first term + tmp_loss = f(x, y) + mrstft_loss += tmp_loss[0] + sc_mag_loss.append(tmp_loss[1]) + log_mag_loss.append(tmp_loss[2]) + lin_mag_loss.append(tmp_loss[3]) + phs_loss.append(tmp_loss[4]) + else: + mrstft_loss += f(x, y) + + mrstft_loss /= len(self.stft_losses) + + if f.output == "loss": + return mrstft_loss + else: + return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + + +class SumAndDifferenceSTFTLoss(torch.nn.Module): + """Sum and difference sttereo STFT loss module. + + See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) + + Args: + fft_sizes (List[int]): List of FFT sizes. + hop_sizes (List[int]): List of hop sizes. + win_lengths (List[int]): List of window lengths. + window (str, optional): Window function type. + w_sum (float, optional): Weight of the sum loss component. Default: 1.0 + w_diff (float, optional): Weight of the difference loss component. Default: 1.0 + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False + n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 + sample_rate (float, optional): Audio sample rate. Default: None + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + """ + + def __init__( + self, + fft_sizes: List[int], + hop_sizes: List[int], + win_lengths: List[int], + window: str = "hann_window", + w_sum: float = 1.0, + w_diff: float = 1.0, + output: str = "loss", + **kwargs, + ): + super().__init__() + self.sd = SumAndDifference() + self.w_sum = w_sum + self.w_diff = w_diff + self.output = output + self.mrstft = MultiResolutionSTFTLoss( + fft_sizes, + hop_sizes, + win_lengths, + window, + **kwargs, + ) + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """This loss function assumes batched input of stereo audio in the time domain. + + Args: + input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). + target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). + + Returns: + loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. + loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + + # compute sum and difference signals for both + input_sum, input_diff = self.sd(input) + target_sum, target_diff = self.sd(target) + + # compute error in STFT domain + sum_loss = self.mrstft(input_sum, target_sum) + diff_loss = self.mrstft(input_diff, target_diff) + loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2 + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sum_loss, diff_loss \ No newline at end of file diff --git a/stable_audio_tools/training/losses/losses.py b/stable_audio_tools/training/losses/losses.py new file mode 100644 index 0000000..15d05ac --- /dev/null +++ b/stable_audio_tools/training/losses/losses.py @@ -0,0 +1,101 @@ +import typing as tp + +from torch.nn import functional as F +from torch import nn +import torch +class LossModule(nn.Module): + def __init__(self, name: str, weight: float = 1.0): + super().__init__() + + self.name = name + self.weight = weight + + def forward(self, info, *args, **kwargs): + raise NotImplementedError + +class ValueLoss(LossModule): + def __init__(self, key: str, name, weight: float = 1.0): + super().__init__(name=name, weight=weight) + + self.key = key + + def forward(self, info): + return self.weight * info[self.key] + +class L1Loss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info: + mse_loss = mse_loss[info[self.mask_key]] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class MSELoss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: + mask = info[self.mask_key] + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss[mask] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class AuralossLoss(LossModule): + def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): + super().__init__(name, weight) + + self.auraloss_module = auraloss_module + + self.input_key = input_key + self.target_key = target_key + + def forward(self, info): + loss = self.auraloss_module(info[self.input_key], info[self.target_key]) + + return self.weight * loss + +class MultiLoss(nn.Module): + def __init__(self, losses: tp.List[LossModule]): + super().__init__() + + self.losses = nn.ModuleList(losses) + + def forward(self, info): + total_loss = 0 + + losses = {} + + for loss_module in self.losses: + module_loss = loss_module(info) + total_loss += module_loss + losses[loss_module.name] = module_loss + + return total_loss, losses \ No newline at end of file diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py new file mode 100644 index 0000000..38a3fcc --- /dev/null +++ b/stable_audio_tools/training/utils.py @@ -0,0 +1,111 @@ +import torch +import os + +def get_rank(): + """Get rank of current process.""" + + print(os.environ.keys()) + + if "SLURM_PROCID" in os.environ: + return int(os.environ["SLURM_PROCID"]) + + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + return torch.distributed.get_rank() + +class InverseLR(torch.optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + final_lr (float): The final learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.final_lr = final_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + import warnings + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.final_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + +def copy_state_dict(model, state_dict): + """Load state_dict to model, but only for keys that match exactly. + + Args: + model (nn.Module): model to load state_dict. + state_dict (OrderedDict): state_dict to load. + """ + model_state_dict = model.state_dict() + for key in state_dict: + if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: + if isinstance(state_dict[key], torch.nn.Parameter): + # backwards compatibility for serialized parameters + state_dict[key] = state_dict[key].data + model_state_dict[key] = state_dict[key] + + model.load_state_dict(model_state_dict, strict=False) + +def create_optimizer_from_config(optimizer_config, parameters): + """Create optimizer from config. + + Args: + parameters (iterable): parameters to optimize. + optimizer_config (dict): optimizer config. + + Returns: + torch.optim.Optimizer: optimizer. + """ + + optimizer_type = optimizer_config["type"] + + if optimizer_type == "FusedAdam": + from deepspeed.ops.adam import FusedAdam + optimizer = FusedAdam(parameters, **optimizer_config["config"]) + else: + optimizer_fn = getattr(torch.optim, optimizer_type) + optimizer = optimizer_fn(parameters, **optimizer_config["config"]) + return optimizer + +def create_scheduler_from_config(scheduler_config, optimizer): + """Create scheduler from config. + + Args: + scheduler_config (dict): scheduler config. + optimizer (torch.optim.Optimizer): optimizer. + + Returns: + torch.optim.lr_scheduler._LRScheduler: scheduler. + """ + if scheduler_config["type"] == "InverseLR": + scheduler_fn = InverseLR + else: + scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) + scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) + return scheduler \ No newline at end of file