AudioX
This commit is contained in:
parent
1bf5691a38
commit
0251bea97a
4 changed files with 1075 additions and 1 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -173,6 +173,5 @@ logs/
|
||||||
log/
|
log/
|
||||||
saved_ckpt/
|
saved_ckpt/
|
||||||
wandb/
|
wandb/
|
||||||
data/
|
|
||||||
demo_result/
|
demo_result/
|
||||||
model/
|
model/
|
||||||
0
stable_audio_tools/data/__init__.py
Normal file
0
stable_audio_tools/data/__init__.py
Normal file
876
stable_audio_tools/data/dataset.py
Normal file
876
stable_audio_tools/data/dataset.py
Normal file
|
|
@ -0,0 +1,876 @@
|
||||||
|
import importlib
|
||||||
|
import numpy as np
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import posixpath
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import webdataset as wds
|
||||||
|
|
||||||
|
from aeiou.core import is_silence
|
||||||
|
from os import path
|
||||||
|
from pedalboard.io import AudioFile
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
from typing import Optional, Callable, List
|
||||||
|
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
|
||||||
|
from torchdata.datapipes.iter import Prefetcher
|
||||||
|
|
||||||
|
from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
from memory_profiler import profile
|
||||||
|
|
||||||
|
|
||||||
|
AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
|
||||||
|
|
||||||
|
# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
|
||||||
|
|
||||||
|
def fast_scandir(
|
||||||
|
dir:str, # top-level directory at which to begin scanning
|
||||||
|
ext:list, # list of allowed file extensions,
|
||||||
|
#max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
|
||||||
|
):
|
||||||
|
"very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
|
||||||
|
subfolders, files = [], []
|
||||||
|
ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
|
||||||
|
try: # hope to avoid 'permission denied' by this try
|
||||||
|
for f in os.scandir(dir):
|
||||||
|
try: # 'hope to avoid too many levels of symbolic links' error
|
||||||
|
if f.is_dir():
|
||||||
|
subfolders.append(f.path)
|
||||||
|
elif f.is_file():
|
||||||
|
file_ext = os.path.splitext(f.name)[1].lower()
|
||||||
|
is_hidden = os.path.basename(f.path).startswith(".")
|
||||||
|
|
||||||
|
if file_ext in ext and not is_hidden:
|
||||||
|
files.append(f.path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for dir in list(subfolders):
|
||||||
|
sf, f = fast_scandir(dir, ext)
|
||||||
|
subfolders.extend(sf)
|
||||||
|
files.extend(f)
|
||||||
|
return subfolders, files
|
||||||
|
|
||||||
|
def extract_audio_paths(jsonl_file, exts):
|
||||||
|
audio_paths = []
|
||||||
|
video_paths = []
|
||||||
|
text_prompts = []
|
||||||
|
data_types = []
|
||||||
|
with open(jsonl_file, 'r') as file:
|
||||||
|
for line in file:
|
||||||
|
try:
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
path = data.get('path', '')
|
||||||
|
video_path = data.get('video_path', '')
|
||||||
|
text_prompt = data.get('caption', '')
|
||||||
|
data_type = data.get('type', None)
|
||||||
|
if any(path.endswith(ext) for ext in exts):
|
||||||
|
audio_paths.append(path)
|
||||||
|
video_paths.append(video_path)
|
||||||
|
text_prompts.append(text_prompt)
|
||||||
|
data_types.append(data_type)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"Error decoding JSON line: {line.strip()}")
|
||||||
|
return audio_paths, video_paths, text_prompts, data_types
|
||||||
|
|
||||||
|
def keyword_scandir(
|
||||||
|
dir: str, # top-level directory at which to begin scanning
|
||||||
|
ext: list, # list of allowed file extensions
|
||||||
|
keywords: list, # list of keywords to search for in the file name
|
||||||
|
):
|
||||||
|
"very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
|
||||||
|
subfolders, files = [], []
|
||||||
|
# make keywords case insensitive
|
||||||
|
keywords = [keyword.lower() for keyword in keywords]
|
||||||
|
# add starting period to extensions if needed
|
||||||
|
ext = ['.'+x if x[0] != '.' else x for x in ext]
|
||||||
|
banned_words = ["paxheader", "__macosx"]
|
||||||
|
try: # hope to avoid 'permission denied' by this try
|
||||||
|
for f in os.scandir(dir):
|
||||||
|
try: # 'hope to avoid too many levels of symbolic links' error
|
||||||
|
if f.is_dir():
|
||||||
|
subfolders.append(f.path)
|
||||||
|
elif f.is_file():
|
||||||
|
is_hidden = f.name.split("/")[-1][0] == '.'
|
||||||
|
has_ext = os.path.splitext(f.name)[1].lower() in ext
|
||||||
|
name_lower = f.name.lower()
|
||||||
|
has_keyword = any(
|
||||||
|
[keyword in name_lower for keyword in keywords])
|
||||||
|
has_banned = any(
|
||||||
|
[banned_word in name_lower for banned_word in banned_words])
|
||||||
|
if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
|
||||||
|
files.append(f.path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for dir in list(subfolders):
|
||||||
|
sf, f = keyword_scandir(dir, ext, keywords)
|
||||||
|
subfolders.extend(sf)
|
||||||
|
files.extend(f)
|
||||||
|
return subfolders, files
|
||||||
|
|
||||||
|
def get_audio_filenames(
|
||||||
|
paths: list, # directories in which to search
|
||||||
|
keywords=None,
|
||||||
|
exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
|
||||||
|
):
|
||||||
|
|
||||||
|
"recursively get a list of audio filenames"
|
||||||
|
filenames = []
|
||||||
|
video_filenames = []
|
||||||
|
text_prompts = []
|
||||||
|
data_types = []
|
||||||
|
|
||||||
|
if type(paths) is str:
|
||||||
|
paths = [paths]
|
||||||
|
|
||||||
|
|
||||||
|
if os.path.isdir(paths[0]):
|
||||||
|
for path in paths: # get a list of relevant filenames
|
||||||
|
if keywords is not None:
|
||||||
|
subfolders, files = keyword_scandir(path, exts, keywords)
|
||||||
|
else:
|
||||||
|
subfolders, files = fast_scandir(path, exts)
|
||||||
|
filenames.extend(files)
|
||||||
|
return filenames
|
||||||
|
|
||||||
|
elif os.path.isfile(paths[0]):
|
||||||
|
assert paths[0].endswith('.jsonl')
|
||||||
|
for path in paths:
|
||||||
|
audio_paths, video_paths, text_prompt, data_type = extract_audio_paths(path, exts)
|
||||||
|
filenames.extend(audio_paths)
|
||||||
|
video_filenames.extend(video_paths)
|
||||||
|
text_prompts.extend(text_prompt)
|
||||||
|
data_types.extend(data_type)
|
||||||
|
|
||||||
|
return filenames, video_filenames, text_prompts, data_types
|
||||||
|
|
||||||
|
|
||||||
|
class LocalDatasetConfig:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
path: str,
|
||||||
|
video_fps: int,
|
||||||
|
custom_metadata_fn: Optional[Callable[[str], str]] = None
|
||||||
|
):
|
||||||
|
self.id = id
|
||||||
|
self.path = path
|
||||||
|
self.video_fps = video_fps
|
||||||
|
self.custom_metadata_fn = custom_metadata_fn
|
||||||
|
|
||||||
|
|
||||||
|
# @profile
|
||||||
|
class SampleDataset(torch.utils.data.Dataset):
|
||||||
|
# @profile
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
configs,
|
||||||
|
sample_size=65536,
|
||||||
|
sample_rate=48000,
|
||||||
|
keywords=None,
|
||||||
|
random_crop=True,
|
||||||
|
force_channels="stereo",
|
||||||
|
video_fps=5
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.filenames = []
|
||||||
|
self.video_filenames = []
|
||||||
|
self.text_prompts = []
|
||||||
|
self.data_types = []
|
||||||
|
|
||||||
|
self.augs = torch.nn.Sequential(
|
||||||
|
PhaseFlipper(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.root_paths = []
|
||||||
|
|
||||||
|
self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
|
||||||
|
|
||||||
|
self.force_channels = force_channels
|
||||||
|
|
||||||
|
self.encoding = torch.nn.Sequential(
|
||||||
|
Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
|
||||||
|
Mono() if self.force_channels == "mono" else torch.nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sr = sample_rate
|
||||||
|
|
||||||
|
self.custom_metadata_fns = {}
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
self.video_fps = config.video_fps
|
||||||
|
|
||||||
|
self.root_paths.append(config.path)
|
||||||
|
audio_files, video_files, text_prompt, data_types = get_audio_filenames(config.path, keywords)
|
||||||
|
|
||||||
|
self.filenames.extend(audio_files)
|
||||||
|
self.video_filenames.extend(video_files)
|
||||||
|
self.text_prompts.extend(text_prompt)
|
||||||
|
self.data_types.extend(data_types)
|
||||||
|
if config.custom_metadata_fn is not None:
|
||||||
|
self.custom_metadata_fns[config.path] = config.custom_metadata_fn
|
||||||
|
|
||||||
|
print(f'Found {len(self.filenames)} files')
|
||||||
|
|
||||||
|
|
||||||
|
def load_file(self, filename):
|
||||||
|
ext = filename.split(".")[-1]
|
||||||
|
|
||||||
|
if ext == "mp3":
|
||||||
|
with AudioFile(filename) as f:
|
||||||
|
audio = f.read(f.frames)
|
||||||
|
audio = torch.from_numpy(audio)
|
||||||
|
in_sr = f.samplerate
|
||||||
|
else:
|
||||||
|
audio, in_sr = torchaudio.load(filename, format=ext)
|
||||||
|
|
||||||
|
if in_sr != self.sr:
|
||||||
|
resample_tf = T.Resample(in_sr, self.sr)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.filenames)
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
audio_filename = self.filenames[idx]
|
||||||
|
video_filename = self.video_filenames[idx]
|
||||||
|
text_prompt = self.text_prompts[idx]
|
||||||
|
data_type = self.data_types[idx]
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
audio = self.load_file(audio_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if data_type in ["text_condition-audio", "text_condition-music",
|
||||||
|
"video_condition-audio", "video_condition-music",
|
||||||
|
"text+video_condition-audio","text+video_condition-music"]:
|
||||||
|
if_audio_contition = False
|
||||||
|
audio_prompt = torch.zeros((2, self.sr * 10))
|
||||||
|
elif data_type in ["audio_condition-audio", "audio_condition-music",
|
||||||
|
"uni_condition-audio", "uni_condition-music"]:
|
||||||
|
if_audio_contition = True
|
||||||
|
|
||||||
|
if if_audio_contition:
|
||||||
|
audio_org = audio.clamp(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
|
||||||
|
|
||||||
|
if self.augs is not None:
|
||||||
|
audio = self.augs(audio)
|
||||||
|
|
||||||
|
audio = audio.clamp(-1, 1)
|
||||||
|
|
||||||
|
if if_audio_contition:
|
||||||
|
if data_type.split("-")[-1] == "audio":
|
||||||
|
start_index = max(0, int((seconds_start) * self.sr))
|
||||||
|
end_index = int((seconds_start+10) * self.sr)
|
||||||
|
audio_prompt = audio_org[:, start_index:end_index]
|
||||||
|
|
||||||
|
elif data_type.split("-")[-1] == "music":
|
||||||
|
if seconds_start < 10:
|
||||||
|
start_index = 0
|
||||||
|
end_index = int(10 * self.sr)
|
||||||
|
else:
|
||||||
|
start_index = max(0, int((seconds_start - 10) * self.sr))
|
||||||
|
end_index = int(seconds_start * self.sr)
|
||||||
|
audio_prompt = audio_org[:, start_index:end_index]
|
||||||
|
|
||||||
|
# Encode the file to assist in prediction
|
||||||
|
if self.encoding is not None:
|
||||||
|
audio = self.encoding(audio)
|
||||||
|
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
|
||||||
|
info["path"] = audio_filename
|
||||||
|
info["video_path"] = video_filename
|
||||||
|
info["text_prompt"] = text_prompt
|
||||||
|
info["audio_prompt"] = audio_prompt
|
||||||
|
info["data_type"] = data_type
|
||||||
|
|
||||||
|
for root_path in self.root_paths:
|
||||||
|
if root_path in audio_filename:
|
||||||
|
info["relpath"] = path.relpath(audio_filename, root_path)
|
||||||
|
|
||||||
|
info["timestamps"] = (t_start, t_end)
|
||||||
|
info["seconds_start"] = seconds_start
|
||||||
|
info["seconds_total"] = seconds_total
|
||||||
|
info["padding_mask"] = padding_mask
|
||||||
|
info["video_fps"] = self.video_fps
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
info["load_time"] = end_time - start_time
|
||||||
|
|
||||||
|
for custom_md_path in self.custom_metadata_fns.keys():
|
||||||
|
if os.path.isdir(custom_md_path):
|
||||||
|
if custom_md_path in audio_filename:
|
||||||
|
custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
|
||||||
|
custom_metadata = custom_metadata_fn(info, audio)
|
||||||
|
info.update(custom_metadata)
|
||||||
|
elif os.path.isfile(custom_md_path):
|
||||||
|
custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
|
||||||
|
custom_metadata = custom_metadata_fn(info, audio)
|
||||||
|
info.update(custom_metadata)
|
||||||
|
|
||||||
|
if "__reject__" in info and info["__reject__"]:
|
||||||
|
return self[random.randrange(len(self))]
|
||||||
|
|
||||||
|
file_name = audio_filename.split('/')[-1]
|
||||||
|
|
||||||
|
return (audio, info)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Couldn\'t load file {audio_filename}: {e}')
|
||||||
|
return self[random.randrange(len(self))]
|
||||||
|
|
||||||
|
def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
|
||||||
|
"""Return function over iterator that groups key, value pairs into samples.
|
||||||
|
:param keys: function that splits the key into key and extension (base_plus_ext)
|
||||||
|
:param lcase: convert suffixes to lower case (Default value = True)
|
||||||
|
"""
|
||||||
|
current_sample = None
|
||||||
|
for filesample in data:
|
||||||
|
assert isinstance(filesample, dict)
|
||||||
|
fname, value = filesample["fname"], filesample["data"]
|
||||||
|
prefix, suffix = keys(fname)
|
||||||
|
if wds.tariterators.trace:
|
||||||
|
print(
|
||||||
|
prefix,
|
||||||
|
suffix,
|
||||||
|
current_sample.keys() if isinstance(current_sample, dict) else None,
|
||||||
|
)
|
||||||
|
if prefix is None:
|
||||||
|
continue
|
||||||
|
if lcase:
|
||||||
|
suffix = suffix.lower()
|
||||||
|
if current_sample is None or prefix != current_sample["__key__"]:
|
||||||
|
if wds.tariterators.valid_sample(current_sample):
|
||||||
|
yield current_sample
|
||||||
|
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
||||||
|
if suffix in current_sample:
|
||||||
|
print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
|
||||||
|
if suffixes is None or suffix in suffixes:
|
||||||
|
current_sample[suffix] = value
|
||||||
|
if wds.tariterators.valid_sample(current_sample):
|
||||||
|
yield current_sample
|
||||||
|
|
||||||
|
wds.tariterators.group_by_keys = group_by_keys
|
||||||
|
|
||||||
|
# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
|
||||||
|
|
||||||
|
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
|
||||||
|
"""
|
||||||
|
Returns a list of full S3 paths to files in a given S3 bucket and directory path.
|
||||||
|
"""
|
||||||
|
# Ensure dataset_path ends with a trailing slash
|
||||||
|
if dataset_path != '' and not dataset_path.endswith('/'):
|
||||||
|
dataset_path += '/'
|
||||||
|
# Use posixpath to construct the S3 URL path
|
||||||
|
bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
|
||||||
|
# Construct the `aws s3 ls` command
|
||||||
|
cmd = ['aws', 's3', 'ls', bucket_path]
|
||||||
|
|
||||||
|
if profile is not None:
|
||||||
|
cmd.extend(['--profile', profile])
|
||||||
|
|
||||||
|
if recursive:
|
||||||
|
# Add the --recursive flag if requested
|
||||||
|
cmd.append('--recursive')
|
||||||
|
|
||||||
|
# Run the `aws s3 ls` command and capture the output
|
||||||
|
run_ls = subprocess.run(cmd, capture_output=True, check=True)
|
||||||
|
# Split the output into lines and strip whitespace from each line
|
||||||
|
contents = run_ls.stdout.decode('utf-8').split('\n')
|
||||||
|
contents = [x.strip() for x in contents if x]
|
||||||
|
# Remove the timestamp from lines that begin with a timestamp
|
||||||
|
contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
|
||||||
|
if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
|
||||||
|
# Construct a full S3 path for each file in the contents list
|
||||||
|
contents = [posixpath.join(s3_url_prefix or '', x)
|
||||||
|
for x in contents if not x.endswith('/')]
|
||||||
|
# Apply the filter, if specified
|
||||||
|
if filter:
|
||||||
|
contents = [x for x in contents if filter in x]
|
||||||
|
# Remove redundant directory names in the S3 URL
|
||||||
|
if recursive:
|
||||||
|
# Get the main directory name from the S3 URL
|
||||||
|
main_dir = "/".join(bucket_path.split('/')[3:])
|
||||||
|
# Remove the redundant directory names from each file path
|
||||||
|
contents = [x.replace(f'{main_dir}', '').replace(
|
||||||
|
'//', '/') for x in contents]
|
||||||
|
# Print debugging information, if requested
|
||||||
|
if debug:
|
||||||
|
print("contents = \n", contents)
|
||||||
|
# Return the list of S3 paths to files
|
||||||
|
return contents
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_s3_urls(
|
||||||
|
names=[], # list of all valid [LAION AudioDataset] dataset names
|
||||||
|
# list of subsets you want from those datasets, e.g. ['train','valid']
|
||||||
|
subsets=[''],
|
||||||
|
s3_url_prefix=None, # prefix for those dataset names
|
||||||
|
recursive=True, # recursively list all tar files in all subdirs
|
||||||
|
filter_str='tar', # only grab files with this substring
|
||||||
|
# print debugging info -- note: info displayed likely to change at dev's whims
|
||||||
|
debug=False,
|
||||||
|
profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
|
||||||
|
):
|
||||||
|
"get urls of shards (tar files) for multiple datasets in one s3 bucket"
|
||||||
|
urls = []
|
||||||
|
for name in names:
|
||||||
|
# If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
|
||||||
|
if s3_url_prefix is None:
|
||||||
|
contents_str = name
|
||||||
|
else:
|
||||||
|
# Construct the S3 path using the s3_url_prefix and the current name value
|
||||||
|
contents_str = posixpath.join(s3_url_prefix, name)
|
||||||
|
if debug:
|
||||||
|
print(f"get_all_s3_urls: {contents_str}:")
|
||||||
|
for subset in subsets:
|
||||||
|
subset_str = posixpath.join(contents_str, subset)
|
||||||
|
if debug:
|
||||||
|
print(f"subset_str = {subset_str}")
|
||||||
|
# Get the list of tar files in the current subset directory
|
||||||
|
profile = profiles.get(name, None)
|
||||||
|
tar_list = get_s3_contents(
|
||||||
|
subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
|
||||||
|
for tar in tar_list:
|
||||||
|
# Escape spaces and parentheses in the tar filename for use in the shell command
|
||||||
|
tar = tar.replace(" ", "\ ").replace(
|
||||||
|
"(", "\(").replace(")", "\)")
|
||||||
|
# Construct the S3 path to the current tar file
|
||||||
|
s3_path = posixpath.join(name, subset, tar) + " -"
|
||||||
|
# Construct the AWS CLI command to download the current tar file
|
||||||
|
if s3_url_prefix is None:
|
||||||
|
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
|
||||||
|
else:
|
||||||
|
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
|
||||||
|
if profiles.get(name):
|
||||||
|
request_str += f" --profile {profiles.get(name)}"
|
||||||
|
if debug:
|
||||||
|
print("request_str = ", request_str)
|
||||||
|
# Add the constructed URL to the list of URLs
|
||||||
|
urls.append(request_str)
|
||||||
|
return urls
|
||||||
|
|
||||||
|
|
||||||
|
def log_and_continue(exn):
|
||||||
|
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
|
||||||
|
print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_sample(sample):
|
||||||
|
has_json = "json" in sample
|
||||||
|
has_audio = "audio" in sample
|
||||||
|
is_silent = is_silence(sample["audio"])
|
||||||
|
is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
|
||||||
|
|
||||||
|
return has_json and has_audio and not is_silent and not is_rejected
|
||||||
|
|
||||||
|
class S3DatasetConfig:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
s3_path: str,
|
||||||
|
custom_metadata_fn: Optional[Callable[[str], str]] = None,
|
||||||
|
profile: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.id = id
|
||||||
|
self.path = s3_path
|
||||||
|
self.custom_metadata_fn = custom_metadata_fn
|
||||||
|
self.profile = profile
|
||||||
|
self.urls = []
|
||||||
|
|
||||||
|
def load_data_urls(self):
|
||||||
|
self.urls = get_all_s3_urls(
|
||||||
|
names=[self.path],
|
||||||
|
s3_url_prefix=None,
|
||||||
|
recursive=True,
|
||||||
|
profiles={self.path: self.profile} if self.profile else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.urls
|
||||||
|
|
||||||
|
class LocalWebDatasetConfig:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
path: str,
|
||||||
|
custom_metadata_fn: Optional[Callable[[str], str]] = None,
|
||||||
|
profile: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.id = id
|
||||||
|
self.path = path
|
||||||
|
self.custom_metadata_fn = custom_metadata_fn
|
||||||
|
self.urls = []
|
||||||
|
|
||||||
|
def load_data_urls(self):
|
||||||
|
|
||||||
|
self.urls = fast_scandir(self.path, ["tar"])[1]
|
||||||
|
|
||||||
|
return self.urls
|
||||||
|
|
||||||
|
def audio_decoder(key, value):
|
||||||
|
# Get file extension from key
|
||||||
|
ext = key.split(".")[-1]
|
||||||
|
|
||||||
|
if ext in AUDIO_KEYS:
|
||||||
|
return torchaudio.load(io.BytesIO(value))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def collation_fn(samples):
|
||||||
|
batched = list(zip(*samples))
|
||||||
|
result = []
|
||||||
|
for b in batched:
|
||||||
|
if isinstance(b[0], (int, float)):
|
||||||
|
b = np.array(b)
|
||||||
|
elif isinstance(b[0], torch.Tensor):
|
||||||
|
b = torch.stack(b)
|
||||||
|
elif isinstance(b[0], np.ndarray):
|
||||||
|
b = np.array(b)
|
||||||
|
else:
|
||||||
|
b = b
|
||||||
|
result.append(b)
|
||||||
|
return result
|
||||||
|
|
||||||
|
class WebDatasetDataLoader():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
datasets: List[S3DatasetConfig],
|
||||||
|
batch_size,
|
||||||
|
sample_size,
|
||||||
|
sample_rate=48000,
|
||||||
|
num_workers=8,
|
||||||
|
epoch_steps=1000,
|
||||||
|
random_crop=True,
|
||||||
|
force_channels="stereo",
|
||||||
|
augment_phase=True,
|
||||||
|
**data_loader_kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
self.datasets = datasets
|
||||||
|
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.random_crop = random_crop
|
||||||
|
self.force_channels = force_channels
|
||||||
|
self.augment_phase = augment_phase
|
||||||
|
|
||||||
|
urls = [dataset.load_data_urls() for dataset in datasets]
|
||||||
|
|
||||||
|
# Flatten the list of lists of URLs
|
||||||
|
urls = [url for dataset_urls in urls for url in dataset_urls]
|
||||||
|
|
||||||
|
# Shuffle the urls
|
||||||
|
random.shuffle(urls)
|
||||||
|
|
||||||
|
self.dataset = wds.DataPipeline(
|
||||||
|
wds.ResampledShards(urls),
|
||||||
|
wds.tarfile_to_samples(handler=log_and_continue),
|
||||||
|
wds.decode(audio_decoder, handler=log_and_continue),
|
||||||
|
wds.map(self.wds_preprocess, handler=log_and_continue),
|
||||||
|
wds.select(is_valid_sample),
|
||||||
|
wds.to_tuple("audio", "json", handler=log_and_continue),
|
||||||
|
#wds.shuffle(bufsize=1000, initial=5000),
|
||||||
|
wds.batched(batch_size, partial=False, collation_fn=collation_fn),
|
||||||
|
).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
|
||||||
|
|
||||||
|
self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
|
||||||
|
|
||||||
|
def wds_preprocess(self, sample):
|
||||||
|
|
||||||
|
found_key, rewrite_key = '', ''
|
||||||
|
for k, v in sample.items(): # print the all entries in dict
|
||||||
|
for akey in AUDIO_KEYS:
|
||||||
|
if k.endswith(akey):
|
||||||
|
# to rename long/weird key with its simpler counterpart
|
||||||
|
found_key, rewrite_key = k, akey
|
||||||
|
break
|
||||||
|
if '' != found_key:
|
||||||
|
break
|
||||||
|
if '' == found_key: # got no audio!
|
||||||
|
return None # try returning None to tell WebDataset to skip this one
|
||||||
|
|
||||||
|
audio, in_sr = sample[found_key]
|
||||||
|
if in_sr != self.sample_rate:
|
||||||
|
resample_tf = T.Resample(in_sr, self.sample_rate)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
|
||||||
|
if self.sample_size is not None:
|
||||||
|
# Pad/crop and get the relative timestamp
|
||||||
|
pad_crop = PadCrop_Normalized_T(
|
||||||
|
self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
|
||||||
|
audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
|
||||||
|
audio)
|
||||||
|
sample["json"]["seconds_start"] = seconds_start
|
||||||
|
sample["json"]["seconds_total"] = seconds_total
|
||||||
|
sample["json"]["padding_mask"] = padding_mask
|
||||||
|
else:
|
||||||
|
t_start, t_end = 0, 1
|
||||||
|
|
||||||
|
# Check if audio is length zero, initialize to a single zero if so
|
||||||
|
if audio.shape[-1] == 0:
|
||||||
|
audio = torch.zeros(1, 1)
|
||||||
|
|
||||||
|
# Make the audio stereo and augment by randomly inverting phase
|
||||||
|
augs = torch.nn.Sequential(
|
||||||
|
Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
|
||||||
|
Mono() if self.force_channels == "mono" else torch.nn.Identity(),
|
||||||
|
PhaseFlipper() if self.augment_phase else torch.nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
audio = augs(audio)
|
||||||
|
|
||||||
|
sample["json"]["timestamps"] = (t_start, t_end)
|
||||||
|
|
||||||
|
if "text" in sample["json"]:
|
||||||
|
sample["json"]["prompt"] = sample["json"]["text"]
|
||||||
|
|
||||||
|
# Check for custom metadata functions
|
||||||
|
for dataset in self.datasets:
|
||||||
|
if dataset.custom_metadata_fn is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if dataset.path in sample["__url__"]:
|
||||||
|
custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
|
||||||
|
sample["json"].update(custom_metadata)
|
||||||
|
|
||||||
|
if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
|
||||||
|
del sample[found_key]
|
||||||
|
|
||||||
|
sample["audio"] = audio
|
||||||
|
|
||||||
|
# Add audio to the metadata as well for conditioning
|
||||||
|
sample["json"]["audio"] = audio
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, video_fps=5):
|
||||||
|
|
||||||
|
dataset_type = dataset_config.get("dataset_type", None)
|
||||||
|
|
||||||
|
assert dataset_type is not None, "Dataset type must be specified in dataset config"
|
||||||
|
|
||||||
|
if audio_channels == 1:
|
||||||
|
force_channels = "mono"
|
||||||
|
else:
|
||||||
|
force_channels = "stereo"
|
||||||
|
|
||||||
|
if dataset_type == "audio_dir":
|
||||||
|
|
||||||
|
audio_dir_configs = dataset_config.get("datasets", None)
|
||||||
|
|
||||||
|
assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
|
||||||
|
|
||||||
|
configs = []
|
||||||
|
|
||||||
|
for audio_dir_config in audio_dir_configs:
|
||||||
|
audio_dir_path = audio_dir_config.get("path", None)
|
||||||
|
assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
|
||||||
|
|
||||||
|
custom_metadata_fn = None
|
||||||
|
custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
|
||||||
|
|
||||||
|
if custom_metadata_module_path is not None:
|
||||||
|
spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
|
||||||
|
metadata_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(metadata_module)
|
||||||
|
|
||||||
|
custom_metadata_fn = metadata_module.get_custom_metadata
|
||||||
|
|
||||||
|
configs.append(
|
||||||
|
LocalDatasetConfig(
|
||||||
|
id=audio_dir_config["id"],
|
||||||
|
path=audio_dir_path,
|
||||||
|
custom_metadata_fn=custom_metadata_fn,
|
||||||
|
video_fps=video_fps
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
train_set = SampleDataset(
|
||||||
|
configs,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_size=sample_size,
|
||||||
|
random_crop=dataset_config.get("random_crop", True),
|
||||||
|
force_channels=force_channels,
|
||||||
|
video_fps=video_fps
|
||||||
|
)
|
||||||
|
|
||||||
|
return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
|
||||||
|
num_workers=num_workers, persistent_workers=True, pin_memory=False, drop_last=True, collate_fn=collation_fn)
|
||||||
|
|
||||||
|
elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
|
||||||
|
wds_configs = []
|
||||||
|
|
||||||
|
for wds_config in dataset_config["datasets"]:
|
||||||
|
|
||||||
|
custom_metadata_fn = None
|
||||||
|
custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
|
||||||
|
|
||||||
|
if custom_metadata_module_path is not None:
|
||||||
|
spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
|
||||||
|
metadata_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(metadata_module)
|
||||||
|
|
||||||
|
custom_metadata_fn = metadata_module.get_custom_metadata
|
||||||
|
|
||||||
|
if "s3_path" in wds_config:
|
||||||
|
|
||||||
|
wds_configs.append(
|
||||||
|
S3DatasetConfig(
|
||||||
|
id=wds_config["id"],
|
||||||
|
s3_path=wds_config["s3_path"],
|
||||||
|
custom_metadata_fn=custom_metadata_fn,
|
||||||
|
profile=wds_config.get("profile", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "path" in wds_config:
|
||||||
|
|
||||||
|
wds_configs.append(
|
||||||
|
LocalWebDatasetConfig(
|
||||||
|
id=wds_config["id"],
|
||||||
|
path=wds_config["path"],
|
||||||
|
custom_metadata_fn=custom_metadata_fn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return WebDatasetDataLoader(
|
||||||
|
wds_configs,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_size=sample_size,
|
||||||
|
batch_size=batch_size,
|
||||||
|
random_crop=dataset_config.get("random_crop", True),
|
||||||
|
num_workers=num_workers,
|
||||||
|
persistent_workers=True,
|
||||||
|
force_channels=force_channels,
|
||||||
|
epoch_steps=dataset_config.get("epoch_steps", 2000)
|
||||||
|
).data_loader
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloader_from_config_valid(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
|
||||||
|
|
||||||
|
|
||||||
|
dataset_type = dataset_config.get("dataset_type", None)
|
||||||
|
|
||||||
|
assert dataset_type is not None, "Dataset type must be specified in dataset config"
|
||||||
|
|
||||||
|
if audio_channels == 1:
|
||||||
|
force_channels = "mono"
|
||||||
|
else:
|
||||||
|
force_channels = "stereo"
|
||||||
|
|
||||||
|
if dataset_type == "audio_dir":
|
||||||
|
|
||||||
|
audio_dir_configs = dataset_config.get("datasets", None)
|
||||||
|
|
||||||
|
assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
|
||||||
|
|
||||||
|
configs = []
|
||||||
|
|
||||||
|
for audio_dir_config in audio_dir_configs:
|
||||||
|
audio_dir_path = audio_dir_config.get("path", None)
|
||||||
|
assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
|
||||||
|
|
||||||
|
custom_metadata_fn = None
|
||||||
|
custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
|
||||||
|
|
||||||
|
if custom_metadata_module_path is not None:
|
||||||
|
spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
|
||||||
|
metadata_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(metadata_module)
|
||||||
|
|
||||||
|
custom_metadata_fn = metadata_module.get_custom_metadata
|
||||||
|
|
||||||
|
configs.append(
|
||||||
|
LocalDatasetConfig(
|
||||||
|
id=audio_dir_config["id"],
|
||||||
|
path=audio_dir_path,
|
||||||
|
custom_metadata_fn=custom_metadata_fn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_set = SampleDataset(
|
||||||
|
configs,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_size=sample_size,
|
||||||
|
random_crop=dataset_config.get("random_crop", True),
|
||||||
|
force_channels=force_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
return torch.utils.data.DataLoader(valid_set, batch_size, shuffle=False,
|
||||||
|
num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
|
||||||
|
|
||||||
|
elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
|
||||||
|
wds_configs = []
|
||||||
|
|
||||||
|
for wds_config in dataset_config["datasets"]:
|
||||||
|
|
||||||
|
custom_metadata_fn = None
|
||||||
|
custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
|
||||||
|
|
||||||
|
if custom_metadata_module_path is not None:
|
||||||
|
spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
|
||||||
|
metadata_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(metadata_module)
|
||||||
|
|
||||||
|
custom_metadata_fn = metadata_module.get_custom_metadata
|
||||||
|
|
||||||
|
if "s3_path" in wds_config:
|
||||||
|
|
||||||
|
wds_configs.append(
|
||||||
|
S3DatasetConfig(
|
||||||
|
id=wds_config["id"],
|
||||||
|
s3_path=wds_config["s3_path"],
|
||||||
|
custom_metadata_fn=custom_metadata_fn,
|
||||||
|
profile=wds_config.get("profile", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "path" in wds_config:
|
||||||
|
|
||||||
|
wds_configs.append(
|
||||||
|
LocalWebDatasetConfig(
|
||||||
|
id=wds_config["id"],
|
||||||
|
path=wds_config["path"],
|
||||||
|
custom_metadata_fn=custom_metadata_fn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return WebDatasetDataLoader(
|
||||||
|
wds_configs,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
sample_size=sample_size,
|
||||||
|
batch_size=batch_size,
|
||||||
|
random_crop=dataset_config.get("random_crop", True),
|
||||||
|
num_workers=num_workers,
|
||||||
|
persistent_workers=True,
|
||||||
|
force_channels=force_channels,
|
||||||
|
epoch_steps=dataset_config.get("epoch_steps", 2000)
|
||||||
|
).data_loader
|
||||||
|
|
||||||
199
stable_audio_tools/data/utils.py
Normal file
199
stable_audio_tools/data/utils.py
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from typing import Tuple
|
||||||
|
import os
|
||||||
|
import subprocess as sp
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
|
||||||
|
class PadCrop(nn.Module):
|
||||||
|
def __init__(self, n_samples, randomize=True):
|
||||||
|
super().__init__()
|
||||||
|
self.n_samples = n_samples
|
||||||
|
self.randomize = randomize
|
||||||
|
|
||||||
|
def __call__(self, signal):
|
||||||
|
n, s = signal.shape
|
||||||
|
start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
|
||||||
|
end = start + self.n_samples
|
||||||
|
output = signal.new_zeros([n, self.n_samples])
|
||||||
|
output[:, :min(s, self.n_samples)] = signal[:, start:end]
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PadCrop_Normalized_T(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.n_samples = n_samples
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.randomize = randomize
|
||||||
|
|
||||||
|
def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
|
||||||
|
n_channels, n_samples = source.shape
|
||||||
|
|
||||||
|
# Calculate the duration of the audio in seconds
|
||||||
|
total_duration = n_samples // self.sample_rate
|
||||||
|
|
||||||
|
# If the audio is shorter than the desired length, pad it
|
||||||
|
upper_bound = max(0, n_samples - self.n_samples)
|
||||||
|
|
||||||
|
# If randomize is False, always start at the beginning of the audio
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
if self.randomize and n_samples > self.n_samples:
|
||||||
|
valid_offsets = [
|
||||||
|
i * self.sample_rate for i in range(0, total_duration, 10)
|
||||||
|
if i * self.sample_rate + self.n_samples <= n_samples and
|
||||||
|
(total_duration <= 20 or total_duration - i >= 15)
|
||||||
|
]
|
||||||
|
if valid_offsets:
|
||||||
|
offset = random.choice(valid_offsets)
|
||||||
|
|
||||||
|
# Calculate the start and end times of the chunk
|
||||||
|
t_start = offset / (upper_bound + self.n_samples)
|
||||||
|
t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
|
||||||
|
|
||||||
|
# Create the chunk
|
||||||
|
chunk = source.new_zeros([n_channels, self.n_samples])
|
||||||
|
|
||||||
|
# Copy the audio into the chunk
|
||||||
|
chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
|
||||||
|
|
||||||
|
# Calculate the start and end times of the chunk in seconds
|
||||||
|
seconds_start = math.floor(offset / self.sample_rate)
|
||||||
|
seconds_total = math.ceil(n_samples / self.sample_rate)
|
||||||
|
|
||||||
|
# Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
|
||||||
|
padding_mask = torch.zeros([self.n_samples])
|
||||||
|
padding_mask[:min(n_samples, self.n_samples)] = 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
chunk,
|
||||||
|
t_start,
|
||||||
|
t_end,
|
||||||
|
seconds_start,
|
||||||
|
seconds_total,
|
||||||
|
padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PhaseFlipper(nn.Module):
|
||||||
|
"Randomly invert the phase of a signal"
|
||||||
|
def __init__(self, p=0.5):
|
||||||
|
super().__init__()
|
||||||
|
self.p = p
|
||||||
|
def __call__(self, signal):
|
||||||
|
return -signal if (random.random() < self.p) else signal
|
||||||
|
|
||||||
|
class Mono(nn.Module):
|
||||||
|
def __call__(self, signal):
|
||||||
|
return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
|
||||||
|
|
||||||
|
class Stereo(nn.Module):
|
||||||
|
def __call__(self, signal):
|
||||||
|
signal_shape = signal.shape
|
||||||
|
# Check if it's mono
|
||||||
|
if len(signal_shape) == 1: # s -> 2, s
|
||||||
|
signal = signal.unsqueeze(0).repeat(2, 1)
|
||||||
|
elif len(signal_shape) == 2:
|
||||||
|
if signal_shape[0] == 1: #1, s -> 2, s
|
||||||
|
signal = signal.repeat(2, 1)
|
||||||
|
elif signal_shape[0] > 2: #?, s -> 2,s
|
||||||
|
signal = signal[:2, :]
|
||||||
|
|
||||||
|
return signal
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_video_duration(video_tensor, duration, target_fps):
|
||||||
|
current_duration = video_tensor.shape[0]
|
||||||
|
target_duration = duration * target_fps
|
||||||
|
if current_duration > target_duration:
|
||||||
|
video_tensor = video_tensor[:target_duration]
|
||||||
|
elif current_duration < target_duration:
|
||||||
|
last_frame = video_tensor[-1:]
|
||||||
|
repeat_times = target_duration - current_duration
|
||||||
|
video_tensor = torch.cat((video_tensor, last_frame.repeat(repeat_times, 1, 1, 1)), dim=0)
|
||||||
|
return video_tensor
|
||||||
|
|
||||||
|
def read_video(filepath, seek_time=0., duration=-1, target_fps=2):
|
||||||
|
if filepath is None:
|
||||||
|
return torch.zeros((int(duration * target_fps), 3, 224, 224))
|
||||||
|
|
||||||
|
ext = os.path.splitext(filepath)[1].lower()
|
||||||
|
if ext in ['.jpg', '.jpeg', '.png']:
|
||||||
|
resize_transform = transforms.Resize((224, 224))
|
||||||
|
image = Image.open(filepath).convert("RGB")
|
||||||
|
frame = transforms.ToTensor()(image).unsqueeze(0)
|
||||||
|
frame = resize_transform(frame)
|
||||||
|
target_frames = int(duration * target_fps)
|
||||||
|
frame = frame.repeat(int(math.ceil(target_frames / frame.shape[0])), 1, 1, 1)[:target_frames]
|
||||||
|
assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}"
|
||||||
|
return frame
|
||||||
|
|
||||||
|
vr = VideoReader(filepath, ctx=cpu(0))
|
||||||
|
fps = vr.get_avg_fps()
|
||||||
|
total_frames = len(vr)
|
||||||
|
|
||||||
|
seek_frame = int(seek_time * fps)
|
||||||
|
if duration > 0:
|
||||||
|
total_frames_to_read = int(target_fps * duration)
|
||||||
|
frame_interval = int(math.ceil(fps / target_fps))
|
||||||
|
end_frame = min(seek_frame + total_frames_to_read * frame_interval, total_frames)
|
||||||
|
frame_ids = list(range(seek_frame, end_frame, frame_interval))
|
||||||
|
else:
|
||||||
|
frame_interval = int(math.ceil(fps / target_fps))
|
||||||
|
frame_ids = list(range(0, total_frames, frame_interval))
|
||||||
|
|
||||||
|
frames = vr.get_batch(frame_ids).asnumpy()
|
||||||
|
frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
if frames.shape[2] != 224 or frames.shape[3] != 224:
|
||||||
|
resize_transform = transforms.Resize((224, 224))
|
||||||
|
frames = resize_transform(frames)
|
||||||
|
|
||||||
|
video_tensor = adjust_video_duration(frames, duration, target_fps)
|
||||||
|
assert video_tensor.shape[0] == duration * target_fps, f"The shape of video_tensor is {video_tensor.shape}"
|
||||||
|
return video_tensor
|
||||||
|
|
||||||
|
def merge_video_audio(video_path, audio_path, output_path, start_time, duration):
|
||||||
|
command = [
|
||||||
|
'ffmpeg',
|
||||||
|
'-y',
|
||||||
|
'-ss', str(start_time),
|
||||||
|
'-t', str(duration),
|
||||||
|
'-i', video_path,
|
||||||
|
'-i', audio_path,
|
||||||
|
'-c:v', 'copy',
|
||||||
|
'-c:a', 'aac',
|
||||||
|
'-map', '0:v:0',
|
||||||
|
'-map', '1:a:0',
|
||||||
|
'-shortest',
|
||||||
|
'-strict', 'experimental',
|
||||||
|
output_path
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
sp.run(command, check=True)
|
||||||
|
print(f"Successfully merged audio and video into {output_path}")
|
||||||
|
return output_path
|
||||||
|
except sp.CalledProcessError as e:
|
||||||
|
print(f"Error merging audio and video: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
|
||||||
|
if audio_path is None:
|
||||||
|
return torch.zeros((2, int(sample_rate * seconds_total)))
|
||||||
|
audio_tensor, sr = torchaudio.load(audio_path)
|
||||||
|
start_index = int(sample_rate * seconds_start)
|
||||||
|
target_length = int(sample_rate * seconds_total)
|
||||||
|
end_index = start_index + target_length
|
||||||
|
audio_tensor = audio_tensor[:, start_index:end_index]
|
||||||
|
if audio_tensor.shape[1] < target_length:
|
||||||
|
pad_length = target_length - audio_tensor.shape[1]
|
||||||
|
audio_tensor = F.pad(audio_tensor, (pad_length, 0))
|
||||||
|
return audio_tensor
|
||||||
Loading…
Reference in a new issue