The Code Repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection", in ICASSP 2022.
In this paper, we devise a model, HTS-AT, by combining a Swin Transformer with a token-semantic module and adapting it to audio classification and sound event detection tasks. HTS-AT is an efficient and lightweight audio transformer with a hierarchical structure and only 30 million parameters. It achieves new state-of-the-art (SOTA) results on AudioSet and ESC-50, and equals the SOTA on Speech Command V2. It also achieves better performance in event localization than previous CNN-based models.
If you only want a quick inference setup for your audio, please refer to this script. Thanks to @allenhung1025 for contributing this.

- Download the checkpoint "HTSAT_AudioSet_Saved_1.ckpt" from the link.
- Install Cog.
- Run:

  ```bash
  cog predict -i audio=@{wav_file} --use-cog-base-image=false
  ```
- Expected output (see the parsing sketch after this block):

  ```
  Running prediction...
  [
    [
      137,
      "Music",
      0.5800321102142334
    ],
    [
      0,
      "Speech",
      0.5537758469581604
    ],
    [
      472,
      "Whip",
      0.5477684736251831
    ]
  ]
  ```
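The prediction is a JSON list of [class_index, label, score] triples. As a minimal sketch of consuming it programmatically (the raw string below simply mirrors the example output above):

```python
import json

# Parse the JSON portion of the Cog output (a list of
# [class_index, label, score] triples) and report the top-scoring label.
# The string below just mirrors the example output shown above.
raw = """[[137, "Music", 0.5800321102142334],
          [0, "Speech", 0.5537758469581604],
          [472, "Whip", 0.5477684736251831]]"""
preds = json.loads(raw)
index, label, score = max(preds, key=lambda p: p[2])
print(f"Top prediction: {label} (class {index}, score {score:.3f})")
```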
You can install the dependencies via:

```bash
pip install -r requirements.txt
```

We do not include PyTorch in the requirements file, since different machines require different versions of CUDA and toolkits. Make sure you install PyTorch following the official guidance.
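After installing PyTorch, a quick sanity check (assuming you intend to train on a CUDA-capable machine) is:

```python
import torch

# Confirm the PyTorch installation and that CUDA is visible.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```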
Another way is to create a conda environment (replace your_env_name with your desired environment name) with all the necessary dependencies:

```bash
conda create -n your_env_name -c pytorch -c nvidia -c conda-forge pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 sox ffmpeg h5py=3.6.0 librosa==0.8.1 matplotlib==3.5.1 numpy==1.22 pandas==1.4.0 scikit-learn==1.0.2 scipy==1.7.3 tensorboard==2.8.0 pytorch-lightning==1.5.9
```

Here we use PyTorch 2.5.1 as an example.

Activate the environment:

```bash
conda activate your_env_name
```
Install extra packages using pip:

```bash
pip install museval==0.4.0 torchcontrib==0.0.2 torchlibrosa==0.0.9 tqdm==4.62.3 wget notebook ipywidgets gdown
```
If you are running on Linux, ensure that SoX and FFmpeg are installed. Although these packages are included in the conda command above, you may also install them manually if needed:

```bash
sudo apt install sox
conda install -c conda-forge ffmpeg
```
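A small sketch to verify that both tools are reachable on your PATH:

```python
import shutil

# Check that the command-line audio tools used for preprocessing are available.
for tool in ("sox", "ffmpeg"):
    path = shutil.which(tool)
    print(f"{tool}: {path if path else 'NOT FOUND'}")
```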
Before starting training or evaluation, you must prepare your datasets. Edit the config.py file to set the following variables:

- dataset_path: path to your processed dataset folder.
- desed_folder: path to your DESED folder (if applicable).
- classes_num: number of classes (e.g., 527 for AudioSet).
- Index the Data:
  Adjust the paths in the ./create_index.sh script if necessary, then run:

  ```bash
  ./create_index.sh
  ```

- Save Class Information:
  Count the number of samples per class and save the information to .npy files (a minimal sketch of this step follows this list):

  ```bash
  python main.py save_idc
  ```
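For reference, here is a minimal sketch (not the repository's implementation) of what this per-class counting step amounts to; the multi-hot targets array and the output filename are placeholders:

```python
import numpy as np

# Count how many samples carry each class label and save the counts to .npy.
# `targets` is a placeholder (num_samples, classes_num) multi-hot label matrix;
# "samples_per_class.npy" is a hypothetical output filename.
classes_num = 527  # e.g., AudioSet
targets = np.random.randint(0, 2, size=(1000, classes_num))
samples_per_class = targets.sum(axis=0)
np.save("samples_per_class.npy", samples_per_class)
```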
For ESC-50, open the notebook esc-50/prep_esc50.ipynb and run the cells to process the dataset.

For Speech Command V2, open the notebook scv2/prep_scv2.ipynb and run the cells to process the dataset.
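The prep notebooks essentially resample each clip to the rate the model expects. A minimal sketch of that step (paths are placeholders), assuming librosa and its soundfile dependency are installed:

```python
import librosa
import soundfile as sf

# Load a clip and resample it to 32 kHz (the rate used for ESC-50/AudioSet;
# Speech Command V2 uses 16 kHz). Paths are placeholders.
waveform, sr = librosa.load("example.wav", sr=32000)
sf.write("example_32k.wav", waveform, sr)
```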
Generate the .npy data files from the DESED dataset by running:

```bash
python convert_desed.py
```
Pre-trained model checkpoints for AudioSet, ESC-50, Speech Command V2, and DESED are provided. Feel free to download and test these checkpoints.
The config.py
file contains all the configuration settings required to run the code. Read the introductory comments in the file and adjust the settings according to your needs.
IMPORTANT:
Like many Transformer-based models, HTS-AT requires a warm-up phase to prevent underfitting at the beginning of training. The default settings are tuned for the full AudioSet (2.2M samples). If your dataset size differs (e.g., 100K, 1M, 10M samples, etc.), you might need to adjust the warm-up steps or epochs accordingly.
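For illustration, a minimal warm-up schedule in plain PyTorch (not the repository's exact scheduler) looks like this; warmup_steps is the knob to scale with your dataset size:

```python
import torch

# Linear learning-rate warm-up: ramp from ~0 to the base LR over
# `warmup_steps` optimizer steps, then keep the base LR.
model = torch.nn.Linear(10, 2)                      # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup_steps = 2000                                  # scale with dataset size

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)
# Call scheduler.step() after each optimizer.step() during training.
```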
- AudioSet:

  ```python
  dataset_path = "path/to/your/processed/audioset"
  dataset_type = "audioset"
  balanced_data = True
  loss_type = "clip_bce"
  sample_rate = 32000
  hop_size = 320
  classes_num = 527
  ```

- ESC-50:

  ```python
  dataset_path = "path/to/your/processed/esc50"
  dataset_type = "esc-50"
  loss_type = "clip_ce"
  sample_rate = 32000
  hop_size = 320
  classes_num = 50
  ```

- Speech Command V2:

  ```python
  dataset_path = "path/to/your/processed/scv2"
  dataset_type = "scv2"
  loss_type = "clip_bce"
  sample_rate = 16000
  hop_size = 160
  classes_num = 35
  ```

- DESED:

  ```python
  resume_checkpoint = "path/to/your/audioset_checkpoint"
  heatmap_dir = "directory_for_localization_results"
  test_file = "heatmap_output_filename"
  fl_local = True
  fl_dataset = "path/to/your/desed_npy_file"
  ```
Note: The model currently supports single GPU training/testing.
All scripts are executed via main.py.

- Training:

  ```bash
  CUDA_VISIBLE_DEVICES=0 python main.py train
  ```

- Testing:

  ```bash
  CUDA_VISIBLE_DEVICES=0 python main.py test
  ```

- Ensemble Testing:

  ```bash
  CUDA_VISIBLE_DEVICES=0 python main.py esm_test
  ```

  (Check the ensemble settings in config.py.)

- Weight Averaging (a minimal averaging sketch follows this list):

  ```bash
  python main.py weight_average
  ```
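Weight averaging combines several saved checkpoints into one. A minimal sketch (not the repository's implementation; the paths and the "state_dict" key are assumptions about the checkpoint layout):

```python
import torch

# Average the parameter tensors of several checkpoints into a single one.
# Checkpoint paths and the "state_dict" key are assumptions.
ckpt_paths = ["ckpt_epoch_10.ckpt", "ckpt_epoch_20.ckpt", "ckpt_epoch_30.ckpt"]
state_dicts = [torch.load(p, map_location="cpu")["state_dict"] for p in ckpt_paths]

avg_state = {
    key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    for key in state_dicts[0]
}
torch.save({"state_dict": avg_state}, "ckpt_averaged.ckpt")
```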
To perform localization on the DESED dataset:
- Ensure fl_local = True in config.py.
- Run the test:

  ```bash
  CUDA_VISIBLE_DEVICES=0 python main.py test
  ```

- Organize and gather the localization results (a minimal heatmap-to-events sketch follows this list):

  ```bash
  python fl_evaluate.py
  ```

- You can also use the notebook fl_evaluate_f1.ipynb to produce the final localization results.
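As a rough picture of what the gathering step does, here is a minimal sketch (not fl_evaluate.py) that turns a per-frame class activation into (onset, offset) event segments; the frame rate, threshold, and heatmap are placeholders:

```python
import numpy as np

# Threshold a per-frame activation and merge consecutive active frames
# into (onset, offset) pairs in seconds. All values below are placeholders.
frames_per_second = 50
threshold = 0.5
heatmap = np.random.rand(1000)          # activation of one class over time

active = heatmap > threshold
events, start = [], None
for t, on in enumerate(active):
    if on and start is None:
        start = t
    elif not on and start is not None:
        events.append((start / frames_per_second, t / frames_per_second))
        start = None
if start is not None:
    events.append((start / frames_per_second, len(active) / frames_per_second))
print(events[:5])
```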
This repository also includes the htsat_esc_training.ipynb
notebook, which is specifically designed for training the model on the ESC-50 dataset. To prepare and use this notebook:
- Configure for ESC-50:
  Open config.py and set the following parameters for ESC-50:

  ```python
  dataset_path = "path/to/your/processed/esc50"
  dataset_type = "esc-50"
  loss_type = "clip_ce"
  sample_rate = 32000
  hop_size = 320
  classes_num = 50
  ```

- Open the Notebook:
  Launch Jupyter Notebook:

  ```bash
  jupyter notebook
  ```

  and open the htsat_esc_training.ipynb file located in the repository root.

- Run the Cells:
  Execute each cell sequentially. The notebook handles data preprocessing, model initialization, and training specific to the ESC-50 dataset. Follow the inline comments for detailed guidance.
If you use this work in your research, please cite:
```bibtex
@inproceedings{htsat-ke2022,
  author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
  title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
  booktitle = {{ICASSP} 2022}
}
```
Our work is based on the Swin Transformer, a well-known hierarchical vision transformer for image classification.