From 0251bea97ab8d2e4e4c9cfed19c7972daa5b1530 Mon Sep 17 00:00:00 2001 From: ZeyueT Date: Thu, 3 Apr 2025 03:34:05 +0800 Subject: [PATCH] AudioX --- .gitignore | 1 - stable_audio_tools/data/__init__.py | 0 stable_audio_tools/data/dataset.py | 876 ++++++++++++++++++++++++++++ stable_audio_tools/data/utils.py | 199 +++++++ 4 files changed, 1075 insertions(+), 1 deletion(-) create mode 100644 stable_audio_tools/data/__init__.py create mode 100644 stable_audio_tools/data/dataset.py create mode 100644 stable_audio_tools/data/utils.py diff --git a/.gitignore b/.gitignore index a380d64..597f849 100644 --- a/.gitignore +++ b/.gitignore @@ -173,6 +173,5 @@ logs/ log/ saved_ckpt/ wandb/ -data/ demo_result/ model/ \ No newline at end of file diff --git a/stable_audio_tools/data/__init__.py b/stable_audio_tools/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py new file mode 100644 index 0000000..1395322 --- /dev/null +++ b/stable_audio_tools/data/dataset.py @@ -0,0 +1,876 @@ +import importlib +import numpy as np +import io +import os +import posixpath +import random +import re +import subprocess +import time +import torch +import torchaudio +import webdataset as wds + +from aeiou.core import is_silence +from os import path +from pedalboard.io import AudioFile +from torchaudio import transforms as T +from typing import Optional, Callable, List +from torchdata.datapipes.iter import IterDataPipe, IterableWrapper +from torchdata.datapipes.iter import Prefetcher + +from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T +import json + + +import os +import datetime +from memory_profiler import profile + + +AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") + +# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def fast_scandir( + dir:str, # top-level directory at which to begin scanning + ext:list, # list of allowed file extensions, + #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB + ): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + file_ext = os.path.splitext(f.name)[1].lower() + is_hidden = os.path.basename(f.path).startswith(".") + + if file_ext in ext and not is_hidden: + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = fast_scandir(dir, ext) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def extract_audio_paths(jsonl_file, exts): + audio_paths = [] + video_paths = [] + text_prompts = [] + data_types = [] + with open(jsonl_file, 'r') as file: + for line in file: + try: + data = json.loads(line.strip()) + path = data.get('path', '') + video_path = data.get('video_path', '') + text_prompt = data.get('caption', '') + data_type = data.get('type', None) + if any(path.endswith(ext) for ext in exts): + audio_paths.append(path) + video_paths.append(video_path) + text_prompts.append(text_prompt) + data_types.append(data_type) + except json.JSONDecodeError: + print(f"Error decoding JSON line: {line.strip()}") + return audio_paths, video_paths, text_prompts, data_types + +def keyword_scandir( + dir: str, # top-level directory at which to begin scanning + ext: list, # list of allowed file extensions + keywords: list, # list of keywords to search for in the file name +): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + # make keywords case insensitive + keywords = [keyword.lower() for keyword in keywords] + # add starting period to extensions if needed + ext = ['.'+x if x[0] != '.' else x for x in ext] + banned_words = ["paxheader", "__macosx"] + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + is_hidden = f.name.split("/")[-1][0] == '.' + has_ext = os.path.splitext(f.name)[1].lower() in ext + name_lower = f.name.lower() + has_keyword = any( + [keyword in name_lower for keyword in keywords]) + has_banned = any( + [banned_word in name_lower for banned_word in banned_words]) + if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = keyword_scandir(dir, ext, keywords) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def get_audio_filenames( + paths: list, # directories in which to search + keywords=None, + exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] +): + + "recursively get a list of audio filenames" + filenames = [] + video_filenames = [] + text_prompts = [] + data_types = [] + + if type(paths) is str: + paths = [paths] + + + if os.path.isdir(paths[0]): + for path in paths: # get a list of relevant filenames + if keywords is not None: + subfolders, files = keyword_scandir(path, exts, keywords) + else: + subfolders, files = fast_scandir(path, exts) + filenames.extend(files) + return filenames + + elif os.path.isfile(paths[0]): + assert paths[0].endswith('.jsonl') + for path in paths: + audio_paths, video_paths, text_prompt, data_type = extract_audio_paths(path, exts) + filenames.extend(audio_paths) + video_filenames.extend(video_paths) + text_prompts.extend(text_prompt) + data_types.extend(data_type) + + return filenames, video_filenames, text_prompts, data_types + + +class LocalDatasetConfig: + def __init__( + self, + id: str, + path: str, + video_fps: int, + custom_metadata_fn: Optional[Callable[[str], str]] = None + ): + self.id = id + self.path = path + self.video_fps = video_fps + self.custom_metadata_fn = custom_metadata_fn + + +# @profile +class SampleDataset(torch.utils.data.Dataset): + # @profile + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + force_channels="stereo", + video_fps=5 + ): + super().__init__() + self.filenames = [] + self.video_filenames = [] + self.text_prompts = [] + self.data_types = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + + self.force_channels = force_channels + + self.encoding = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + + self.sr = sample_rate + + self.custom_metadata_fns = {} + + for config in configs: + self.video_fps = config.video_fps + + self.root_paths.append(config.path) + audio_files, video_files, text_prompt, data_types = get_audio_filenames(config.path, keywords) + + self.filenames.extend(audio_files) + self.video_filenames.extend(video_files) + self.text_prompts.extend(text_prompt) + self.data_types.extend(data_types) + if config.custom_metadata_fn is not None: + self.custom_metadata_fns[config.path] = config.custom_metadata_fn + + print(f'Found {len(self.filenames)} files') + + + def load_file(self, filename): + ext = filename.split(".")[-1] + + if ext == "mp3": + with AudioFile(filename) as f: + audio = f.read(f.frames) + audio = torch.from_numpy(audio) + in_sr = f.samplerate + else: + audio, in_sr = torchaudio.load(filename, format=ext) + + if in_sr != self.sr: + resample_tf = T.Resample(in_sr, self.sr) + audio = resample_tf(audio) + + return audio + + def __len__(self): + return len(self.filenames) + + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + video_filename = self.video_filenames[idx] + text_prompt = self.text_prompts[idx] + data_type = self.data_types[idx] + + try: + + start_time = time.time() + audio = self.load_file(audio_filename) + + + if data_type in ["text_condition-audio", "text_condition-music", + "video_condition-audio", "video_condition-music", + "text+video_condition-audio","text+video_condition-music"]: + if_audio_contition = False + audio_prompt = torch.zeros((2, self.sr * 10)) + elif data_type in ["audio_condition-audio", "audio_condition-music", + "uni_condition-audio", "uni_condition-music"]: + if_audio_contition = True + + if if_audio_contition: + audio_org = audio.clamp(-1, 1) + + + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio) + + if self.augs is not None: + audio = self.augs(audio) + + audio = audio.clamp(-1, 1) + + if if_audio_contition: + if data_type.split("-")[-1] == "audio": + start_index = max(0, int((seconds_start) * self.sr)) + end_index = int((seconds_start+10) * self.sr) + audio_prompt = audio_org[:, start_index:end_index] + + elif data_type.split("-")[-1] == "music": + if seconds_start < 10: + start_index = 0 + end_index = int(10 * self.sr) + else: + start_index = max(0, int((seconds_start - 10) * self.sr)) + end_index = int(seconds_start * self.sr) + audio_prompt = audio_org[:, start_index:end_index] + + # Encode the file to assist in prediction + if self.encoding is not None: + audio = self.encoding(audio) + + info = {} + + + info["path"] = audio_filename + info["video_path"] = video_filename + info["text_prompt"] = text_prompt + info["audio_prompt"] = audio_prompt + info["data_type"] = data_type + + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + info["timestamps"] = (t_start, t_end) + info["seconds_start"] = seconds_start + info["seconds_total"] = seconds_total + info["padding_mask"] = padding_mask + info["video_fps"] = self.video_fps + end_time = time.time() + + info["load_time"] = end_time - start_time + + for custom_md_path in self.custom_metadata_fns.keys(): + if os.path.isdir(custom_md_path): + if custom_md_path in audio_filename: + custom_metadata_fn = self.custom_metadata_fns[custom_md_path] + custom_metadata = custom_metadata_fn(info, audio) + info.update(custom_metadata) + elif os.path.isfile(custom_md_path): + custom_metadata_fn = self.custom_metadata_fns[custom_md_path] + custom_metadata = custom_metadata_fn(info, audio) + info.update(custom_metadata) + + if "__reject__" in info and info["__reject__"]: + return self[random.randrange(len(self))] + + file_name = audio_filename.split('/')[-1] + + return (audio, info) + except Exception as e: + print(f'Couldn\'t load file {audio_filename}: {e}') + return self[random.randrange(len(self))] + +def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if wds.tariterators.trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if wds.tariterators.valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}") + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if wds.tariterators.valid_sample(current_sample): + yield current_sample + +wds.tariterators.group_by_keys = group_by_keys + +# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): + """ + Returns a list of full S3 paths to files in a given S3 bucket and directory path. + """ + # Ensure dataset_path ends with a trailing slash + if dataset_path != '' and not dataset_path.endswith('/'): + dataset_path += '/' + # Use posixpath to construct the S3 URL path + bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + # Construct the `aws s3 ls` command + cmd = ['aws', 's3', 'ls', bucket_path] + + if profile is not None: + cmd.extend(['--profile', profile]) + + if recursive: + # Add the --recursive flag if requested + cmd.append('--recursive') + + # Run the `aws s3 ls` command and capture the output + run_ls = subprocess.run(cmd, capture_output=True, check=True) + # Split the output into lines and strip whitespace from each line + contents = run_ls.stdout.decode('utf-8').split('\n') + contents = [x.strip() for x in contents if x] + # Remove the timestamp from lines that begin with a timestamp + contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) + if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] + # Construct a full S3 path for each file in the contents list + contents = [posixpath.join(s3_url_prefix or '', x) + for x in contents if not x.endswith('/')] + # Apply the filter, if specified + if filter: + contents = [x for x in contents if filter in x] + # Remove redundant directory names in the S3 URL + if recursive: + # Get the main directory name from the S3 URL + main_dir = "/".join(bucket_path.split('/')[3:]) + # Remove the redundant directory names from each file path + contents = [x.replace(f'{main_dir}', '').replace( + '//', '/') for x in contents] + # Print debugging information, if requested + if debug: + print("contents = \n", contents) + # Return the list of S3 paths to files + return contents + + +def get_all_s3_urls( + names=[], # list of all valid [LAION AudioDataset] dataset names + # list of subsets you want from those datasets, e.g. ['train','valid'] + subsets=[''], + s3_url_prefix=None, # prefix for those dataset names + recursive=True, # recursively list all tar files in all subdirs + filter_str='tar', # only grab files with this substring + # print debugging info -- note: info displayed likely to change at dev's whims + debug=False, + profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} +): + "get urls of shards (tar files) for multiple datasets in one s3 bucket" + urls = [] + for name in names: + # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list + if s3_url_prefix is None: + contents_str = name + else: + # Construct the S3 path using the s3_url_prefix and the current name value + contents_str = posixpath.join(s3_url_prefix, name) + if debug: + print(f"get_all_s3_urls: {contents_str}:") + for subset in subsets: + subset_str = posixpath.join(contents_str, subset) + if debug: + print(f"subset_str = {subset_str}") + # Get the list of tar files in the current subset directory + profile = profiles.get(name, None) + tar_list = get_s3_contents( + subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) + for tar in tar_list: + # Escape spaces and parentheses in the tar filename for use in the shell command + tar = tar.replace(" ", "\ ").replace( + "(", "\(").replace(")", "\)") + # Construct the S3 path to the current tar file + s3_path = posixpath.join(name, subset, tar) + " -" + # Construct the AWS CLI command to download the current tar file + if s3_url_prefix is None: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" + else: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" + if profiles.get(name): + request_str += f" --profile {profiles.get(name)}" + if debug: + print("request_str = ", request_str) + # Add the constructed URL to the list of URLs + urls.append(request_str) + return urls + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + print(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +def is_valid_sample(sample): + has_json = "json" in sample + has_audio = "audio" in sample + is_silent = is_silence(sample["audio"]) + is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"] + + return has_json and has_audio and not is_silent and not is_rejected + +class S3DatasetConfig: + def __init__( + self, + id: str, + s3_path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = s3_path + self.custom_metadata_fn = custom_metadata_fn + self.profile = profile + self.urls = [] + + def load_data_urls(self): + self.urls = get_all_s3_urls( + names=[self.path], + s3_url_prefix=None, + recursive=True, + profiles={self.path: self.profile} if self.profile else {}, + ) + + return self.urls + +class LocalWebDatasetConfig: + def __init__( + self, + id: str, + path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = path + self.custom_metadata_fn = custom_metadata_fn + self.urls = [] + + def load_data_urls(self): + + self.urls = fast_scandir(self.path, ["tar"])[1] + + return self.urls + +def audio_decoder(key, value): + # Get file extension from key + ext = key.split(".")[-1] + + if ext in AUDIO_KEYS: + return torchaudio.load(io.BytesIO(value)) + else: + return None + +def collation_fn(samples): + batched = list(zip(*samples)) + result = [] + for b in batched: + if isinstance(b[0], (int, float)): + b = np.array(b) + elif isinstance(b[0], torch.Tensor): + b = torch.stack(b) + elif isinstance(b[0], np.ndarray): + b = np.array(b) + else: + b = b + result.append(b) + return result + +class WebDatasetDataLoader(): + def __init__( + self, + datasets: List[S3DatasetConfig], + batch_size, + sample_size, + sample_rate=48000, + num_workers=8, + epoch_steps=1000, + random_crop=True, + force_channels="stereo", + augment_phase=True, + **data_loader_kwargs + ): + + self.datasets = datasets + + self.sample_size = sample_size + self.sample_rate = sample_rate + self.random_crop = random_crop + self.force_channels = force_channels + self.augment_phase = augment_phase + + urls = [dataset.load_data_urls() for dataset in datasets] + + # Flatten the list of lists of URLs + urls = [url for dataset_urls in urls for url in dataset_urls] + + # Shuffle the urls + random.shuffle(urls) + + self.dataset = wds.DataPipeline( + wds.ResampledShards(urls), + wds.tarfile_to_samples(handler=log_and_continue), + wds.decode(audio_decoder, handler=log_and_continue), + wds.map(self.wds_preprocess, handler=log_and_continue), + wds.select(is_valid_sample), + wds.to_tuple("audio", "json", handler=log_and_continue), + #wds.shuffle(bufsize=1000, initial=5000), + wds.batched(batch_size, partial=False, collation_fn=collation_fn), + ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) + + self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) + + def wds_preprocess(self, sample): + + found_key, rewrite_key = '', '' + for k, v in sample.items(): # print the all entries in dict + for akey in AUDIO_KEYS: + if k.endswith(akey): + # to rename long/weird key with its simpler counterpart + found_key, rewrite_key = k, akey + break + if '' != found_key: + break + if '' == found_key: # got no audio! + return None # try returning None to tell WebDataset to skip this one + + audio, in_sr = sample[found_key] + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate) + audio = resample_tf(audio) + + if self.sample_size is not None: + # Pad/crop and get the relative timestamp + pad_crop = PadCrop_Normalized_T( + self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( + audio) + sample["json"]["seconds_start"] = seconds_start + sample["json"]["seconds_total"] = seconds_total + sample["json"]["padding_mask"] = padding_mask + else: + t_start, t_end = 0, 1 + + # Check if audio is length zero, initialize to a single zero if so + if audio.shape[-1] == 0: + audio = torch.zeros(1, 1) + + # Make the audio stereo and augment by randomly inverting phase + augs = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + PhaseFlipper() if self.augment_phase else torch.nn.Identity() + ) + + audio = augs(audio) + + sample["json"]["timestamps"] = (t_start, t_end) + + if "text" in sample["json"]: + sample["json"]["prompt"] = sample["json"]["text"] + + # Check for custom metadata functions + for dataset in self.datasets: + if dataset.custom_metadata_fn is None: + continue + + if dataset.path in sample["__url__"]: + custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) + sample["json"].update(custom_metadata) + + if found_key != rewrite_key: # rename long/weird key with its simpler counterpart + del sample[found_key] + + sample["audio"] = audio + + # Add audio to the metadata as well for conditioning + sample["json"]["audio"] = audio + + return sample + +def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, video_fps=5): + + dataset_type = dataset_config.get("dataset_type", None) + + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + else: + force_channels = "stereo" + + if dataset_type == "audio_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + custom_metadata_fn=custom_metadata_fn, + video_fps=video_fps + ) + ) + + train_set = SampleDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + force_channels=force_channels, + video_fps=video_fps + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=False, drop_last=True, collate_fn=collation_fn) + + elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility + wds_configs = [] + + for wds_config in dataset_config["datasets"]: + + custom_metadata_fn = None + custom_metadata_module_path = wds_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + if "s3_path" in wds_config: + + wds_configs.append( + S3DatasetConfig( + id=wds_config["id"], + s3_path=wds_config["s3_path"], + custom_metadata_fn=custom_metadata_fn, + profile=wds_config.get("profile", None), + ) + ) + + elif "path" in wds_config: + + wds_configs.append( + LocalWebDatasetConfig( + id=wds_config["id"], + path=wds_config["path"], + custom_metadata_fn=custom_metadata_fn + ) + ) + + return WebDatasetDataLoader( + wds_configs, + sample_rate=sample_rate, + sample_size=sample_size, + batch_size=batch_size, + random_crop=dataset_config.get("random_crop", True), + num_workers=num_workers, + persistent_workers=True, + force_channels=force_channels, + epoch_steps=dataset_config.get("epoch_steps", 2000) + ).data_loader + + + + +def create_dataloader_from_config_valid(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): + + + dataset_type = dataset_config.get("dataset_type", None) + + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + else: + force_channels = "stereo" + + if dataset_type == "audio_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + custom_metadata_fn=custom_metadata_fn + ) + ) + + valid_set = SampleDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + force_channels=force_channels + ) + + + return torch.utils.data.DataLoader(valid_set, batch_size, shuffle=False, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + + elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility + wds_configs = [] + + for wds_config in dataset_config["datasets"]: + + custom_metadata_fn = None + custom_metadata_module_path = wds_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + if "s3_path" in wds_config: + + wds_configs.append( + S3DatasetConfig( + id=wds_config["id"], + s3_path=wds_config["s3_path"], + custom_metadata_fn=custom_metadata_fn, + profile=wds_config.get("profile", None), + ) + ) + + elif "path" in wds_config: + + wds_configs.append( + LocalWebDatasetConfig( + id=wds_config["id"], + path=wds_config["path"], + custom_metadata_fn=custom_metadata_fn + ) + ) + + return WebDatasetDataLoader( + wds_configs, + sample_rate=sample_rate, + sample_size=sample_size, + batch_size=batch_size, + random_crop=dataset_config.get("random_crop", True), + num_workers=num_workers, + persistent_workers=True, + force_channels=force_channels, + epoch_steps=dataset_config.get("epoch_steps", 2000) + ).data_loader + \ No newline at end of file diff --git a/stable_audio_tools/data/utils.py b/stable_audio_tools/data/utils.py new file mode 100644 index 0000000..be6e09b --- /dev/null +++ b/stable_audio_tools/data/utils.py @@ -0,0 +1,199 @@ +import math +import random +import torch + +from torch import nn +from typing import Tuple +import os +import subprocess as sp +from PIL import Image +from torchvision import transforms +from decord import VideoReader, cpu + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + + +class PadCrop_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + super().__init__() + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]: + n_channels, n_samples = source.shape + + # Calculate the duration of the audio in seconds + total_duration = n_samples // self.sample_rate + + # If the audio is shorter than the desired length, pad it + upper_bound = max(0, n_samples - self.n_samples) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + + if self.randomize and n_samples > self.n_samples: + valid_offsets = [ + i * self.sample_rate for i in range(0, total_duration, 10) + if i * self.sample_rate + self.n_samples <= n_samples and + (total_duration <= 20 or total_duration - i >= 15) + ] + if valid_offsets: + offset = random.choice(valid_offsets) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + + # Create the chunk + chunk = source.new_zeros([n_channels, self.n_samples]) + + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] + + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + + +class PhaseFlipper(nn.Module): + "Randomly invert the phase of a signal" + def __init__(self, p=0.5): + super().__init__() + self.p = p + def __call__(self, signal): + return -signal if (random.random() < self.p) else signal + +class Mono(nn.Module): + def __call__(self, signal): + return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + +class Stereo(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> 2, s + signal = signal.unsqueeze(0).repeat(2, 1) + elif len(signal_shape) == 2: + if signal_shape[0] == 1: #1, s -> 2, s + signal = signal.repeat(2, 1) + elif signal_shape[0] > 2: #?, s -> 2,s + signal = signal[:2, :] + + return signal + + +def adjust_video_duration(video_tensor, duration, target_fps): + current_duration = video_tensor.shape[0] + target_duration = duration * target_fps + if current_duration > target_duration: + video_tensor = video_tensor[:target_duration] + elif current_duration < target_duration: + last_frame = video_tensor[-1:] + repeat_times = target_duration - current_duration + video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0) + return video_tensor + +def read_video(filepath, seek_time=0., duration=-1, target_fps=2): + if filepath is None: + return torch.zeros((int(duration * target_fps), 3, 224, 224)) + + ext = os.path.splitext(filepath)[1].lower() + if ext in ['.jpg', '.jpeg', '.png']: + resize_transform = transforms.Resize((224, 224)) + image = Image.open(filepath).convert("RGB") + frame = transforms.ToTensor()(image).unsqueeze(0) + frame = resize_transform(frame) + target_frames = int(duration * target_fps) + frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames] + assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}" + return frame + + vr = VideoReader(filepath, ctx=cpu(0)) + fps = vr.get_avg_fps() + total_frames = len(vr) + + seek_frame = int(seek_time * fps) + if duration > 0: + total_frames_to_read = int(target_fps * duration) + frame_interval = int(math.ceil(fps / target_fps)) + end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames) + frame_ids = list(range(seek_frame, end_frame, frame_interval)) + else: + frame_interval = int(math.ceil(fps / target_fps)) + frame_ids = list(range(0, total_frames, frame_interval)) + + frames = vr.get_batch(frame_ids).asnumpy() + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) + + if frames.shape[2] != 224 or frames.shape[3] != 224: + resize_transform = transforms.Resize((224, 224)) + frames = resize_transform(frames) + + video_tensor = adjust_video_duration(frames, duration, target_fps) + assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}" + return video_tensor + +def merge_video_audio(video_path, audio_path, output_path, start_time, duration): + command = [ + 'ffmpeg', + '-y', + '-ss', str(start_time), + '-t', str(duration), + '-i', video_path, + '-i', audio_path, + '-c:v', 'copy', + '-c:a', 'aac', + '-map', '0:v:0', + '-map', '1:a:0', + '-shortest', + '-strict', 'experimental', + output_path + ] + + try: + sp.run(command, check=True) + print(f"Successfully merged audio and video into {output_path}") + return output_path + except sp.CalledProcessError as e: + print(f"Error merging audio and video: {e}") + return None + +def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total): + if audio_path is None: + return torch.zeros((2, int(sample_rate * seconds_total))) + audio_tensor, sr = torchaudio.load(audio_path) + start_index = int(sample_rate * seconds_start) + target_length = int(sample_rate * seconds_total) + end_index = start_index + target_length + audio_tensor = audio_tensor[:, start_index:end_index] + if audio_tensor.shape[1] < target_length: + pad_length = target_length - audio_tensor.shape[1] + audio_tensor = F.pad(audio_tensor, (pad_length, 0)) + return audio_tensor \ No newline at end of file