This repository contains the training and modeling code used to create Inworld TTS-1 and TTS-1-Max models.
You can use this code to pre-train, fine-tune, or align with RL any SpeechLM-based TTS model, whether you're on a single-GPU machine or a multi-GPU cluster.
- Modeling: SpeechLM and 1D audio-codecs
- Distributed Training: DDP, DeepSpeed, and FSDP for training arbitrary SpeechLMs
- Data Pipeline: Ready-to-use scripts to vectorize and prepare your audio data for training
The code is only tested on Ubuntu 22.04.
| Component | Version | Notes |
|---|---|---|
| Python | 3.10 | Required for all features |
| CUDA | 12.4 or 12.8 | Auto-detected |
| PyTorch | 2.6 (CUDA 12.4) or 2.7 (CUDA 12.8) | Auto-installed |
This project requires Python 3.10 and uses uv for package management.
Install uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
Default setup (CUDA 12.8 + PyTorch 2.7):
make install
Specify CUDA version:
# For CUDA 12.4 + PyTorch 2.6
make install CUDA_VERSION=12.4
# For CUDA 12.8 + PyTorch 2.7
make install CUDA_VERSION=12.8
This automatically:
- Creates Python 3.10 virtual environment
- Installs CUDA-optimized PyTorch with the proper flash attention implementation
- Sets up all project dependencies
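To verify the environment after installation, a quick check like the following can be run inside the virtual environment created by `make install` (e.g. via `uv run python`):

```python
# Quick sanity check: confirm the CUDA-enabled PyTorch build was installed correctly.
import torch

print(torch.__version__)          # expect 2.6.x (CUDA 12.4) or 2.7.x (CUDA 12.8)
print(torch.version.cuda)         # expect "12.4" or "12.8"
print(torch.cuda.is_available())  # expect True on a GPU machine
print(torch.cuda.device_count())  # number of visible GPUs
```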
To train a SpeechLM, you first need to vectorize your audio data into audio codes. Conditioned on these codes together with the transcript, the model learns to generate the audio. The example below shows how to get started with a simple SFT training run.
Process your raw audio dataset into a JSONL file where each line contains a sample with the following format:
{
"transcript": "Then they would swiftly dart at their prey and bear it to the ground.",
"language": "en",
"wav_path": "/path/to/audio.wav",
"duration": 3.42,
"sample_rate": 24000
}
Required fields:
- `transcript`: Text transcription of the audio
- `language`: Language code (e.g., "en" for English)
- `wav_path`: Absolute path to the audio file
- `duration`: Audio duration in seconds
- `sample_rate`: Audio sample rate in Hz
Example dataset: You can reference the LibriTTS dataset, which contains ~585 hours of English speech from 2,456 speakers at a 24kHz sampling rate.
Sample files: We provide real example data from LibriTTS in this repository:
- `./example/configs/samples.jsonl` - 100 real LibriTTS samples in the proper JSONL format
- `./example/wavs/` - Corresponding audio files (audio_1.wav through audio_100.wav)
This gives you a working example to test the data vectorization and training pipeline with actual audio data.
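For illustration, here is a minimal sketch of how such a JSONL file could be generated from a folder of WAV files. The `soundfile` dependency and the `transcripts` lookup are assumptions for this example, not something provided by the repository:

```python
# Hypothetical helper: build a samples.jsonl from a folder of WAVs plus a transcript lookup.
import json
from pathlib import Path

import soundfile as sf  # assumed audio library; anything that reads WAV metadata works

wav_dir = Path("/path/to/wavs")
transcripts = {"audio_1.wav": "Then they would swiftly dart at their prey and bear it to the ground."}

with open("samples.jsonl", "w", encoding="utf-8") as out:
    for wav_path in sorted(wav_dir.glob("*.wav")):
        info = sf.info(str(wav_path))
        sample = {
            "transcript": transcripts[wav_path.name],
            "language": "en",
            "wav_path": str(wav_path.resolve()),
            "duration": round(info.duration, 2),
            "sample_rate": info.samplerate,
        }
        out.write(json.dumps(sample) + "\n")
```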
Vectorize the audio data using the codec's encoder (we also made the codec compatible with xcodec2, so you can use the publicly available checkpoint if you prefer not to train the codec from scratch):
Test with provided samples:
WANDB_PROJECT="your_project" \
torchrun --nproc_per_node 8 ./tools/data/data_vectorizer.py \
--codec_model_path=/path/to/codec/model.pt \
--batch_size=16 \
--dataset_path=./example/configs/samples.jsonl \
--output_dir=/path/to/output_directory \
--use_wandb \
--run_name=test_vectorization
With your own dataset:
WANDB_PROJECT="your_project" \
torchrun --nproc_per_node 8 ./tools/data/data_vectorizer.py \
--codec_model_path=/path/to/codec/model.pt \
--batch_size=16 \
--dataset_path=/path/to/your_data.jsonl \
--output_dir=/path/to/output_directory \
--use_wandb \
--run_name=vectorization_run
After vectorization completes, you'll have multiple shard files in your output directory like below:
train_codes_{0..n}.npy # Vectorized audio codes for training
train_codes_index_{0..n}.npy # Index mappings for training codes
train_samples_{0..n}.jsonl # Training sample metadata
val_codes_{0..n}.npy # Vectorized audio codes for validation
val_codes_index_{0..n}.npy # Index mappings for validation codes
val_samples_{0..n}.jsonl # Validation sample metadata
Feel free to adapt the filtering logic to implement your own custom filters.
Combine vectorized shards into unified dataset to save space:
python tools/data/data_merger.py \
--dataset_path /path/to/your/vectorized_dataset \
--remove_shards
After merging, your dataset folder will contain:
train_codes.npy # Merged training codes
train_codes_index.npy # Merged training code indices
train_samples.jsonl # Merged training samples
val_codes.npy # Merged validation codes
val_codes_index.npy # Merged validation code indices
val_samples.jsonl # Merged validation samples
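For a quick sanity check of the merged dataset, something like the sketch below can help. Note that the exact layout of `train_codes_index.npy` is an assumption here (the sketch treats row `i` as the start/end offsets of sample `i` in `train_codes.npy`); consult the vectorizer and merger code for the authoritative format:

```python
# Hypothetical inspection of the merged dataset; the index layout is an assumption.
import json

import numpy as np

dataset_dir = "/path/to/your/vectorized_dataset"
codes = np.load(f"{dataset_dir}/train_codes.npy", mmap_mode="r")
index = np.load(f"{dataset_dir}/train_codes_index.npy")

with open(f"{dataset_dir}/train_samples.jsonl", encoding="utf-8") as f:
    samples = [json.loads(line) for line in f]

print(f"{len(samples)} samples, {codes.shape[0]} audio codes in total")

# Assumption: index[i] holds the (start, end) offsets of sample i into the codes array.
start, end = index[0]
print(samples[0]["transcript"])
print(codes[start:end][:16])  # first few audio codes of the first sample
```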
Create a training config (`./example/configs/sft.json`). The snippet below shows the key configuration sections; see the full configuration file at `./example/configs/sft.json` for all available options:
{
"training": {
"seed": 777,
"logging_steps": 150,
"eval_steps": 300,
"learning_rate": 1e-04,
"batch_size": 4,
"precision": "bf16",
"strategy": "ddp"
},
"modeling": {
"parameters": {
"codebook_size": 65536,
"max_seq_len": 2048,
"model_name": "meta-llama/Llama-3.2-1B-Instruct",
"enable_text_normalization": true
}
},
"checkpointing": {
"save_steps": 10000,
"keep_only_last_n_checkpoints": 10
},
"train_weighted_datasets": {
"/path/to/your/vectorized_dataset": 1.0
},
"val_weighted_datasets": {
"/path/to/your/vectorized_dataset": 1.0
},
"dataset": {
"enable_rlhf_training": false
}
}
Important:
- Update the dataset paths to point to your vectorized data directory (see the sketch below)
- This shows only the key parameters - refer to `./example/configs/sft.json` for the complete configuration with all available options
- To resume from a checkpoint, add `"checkpoint_file_to_resume_from": "/path/to/your/checkpoint.pt"` to the `checkpointing` section
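If you prefer to adjust the example config programmatically rather than editing it by hand, a small script along these lines works (the output filename is just an example):

```python
# Copy the example SFT config and point it at your vectorized dataset.
import json

with open("./example/configs/sft.json", encoding="utf-8") as f:
    config = json.load(f)

dataset_dir = "/path/to/your/vectorized_dataset"
config["train_weighted_datasets"] = {dataset_dir: 1.0}
config["val_weighted_datasets"] = {dataset_dir: 1.0}

# Optional: resume training from an existing checkpoint.
# config["checkpointing"]["checkpoint_file_to_resume_from"] = "/path/to/your/checkpoint.pt"

with open("./example/configs/sft_custom.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
```

Then pass `--config_path=./example/configs/sft_custom.json` to the training command below.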
SFT training:
fabric run --devices=$NUM_GPU tts/training/main.py \
--config_path=./example/configs/sft.json \
--use_wandb \
--run_name=my_tts_training
After training completes, you'll find the trained model at ./experiments/my_tts_training/final_model.pt
along with model and tokenizer configuration files.
Additional options:
- `--dry_run`: Test the pipeline without training
- `--compile_model`: Enable torch.compile optimization (works well only if all samples in a batch have the same length)
Track progress via:
- Weights & Biases: Loss curves and training metrics
- Checkpoints: Saved every `save_steps` iterations
- Console logs: Real-time training information
Once you have a trained model, you can use it for inference to generate speech from text and an audio prompt.
Use the provided inference script for easy speech generation:
python tools/inference.py \
--model_checkpoint_path /path/to/your/trained_model.pt \
--audio_encoder_path /path/to/encoder.pt \
--audio_decoder_path /path/to/decoder.pt \
--prompt_wav_path /path/to/your_prompt.wav \
--prompt_transcription "This is what the speaker says in the prompt." \
--text "Hello, this is a test of the text-to-speech system." \
--output_path output.wav
Required components:
- Trained model checkpoint (`.pt` file from training)
- Audio encoder checkpoint (codec `.pt`/`.ckpt` file)
- Audio decoder checkpoint (codec `.pt`/`.ckpt` file + `model_config.json` in the same directory)
- Audio prompt file (`.wav` format)
- Prompt transcription (text of what's spoken in the audio prompt)
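Before launching inference, it can be useful to verify that these files exist and to inspect the audio prompt. This is only a convenience sketch (the 24kHz expectation mirrors the LibriTTS example above; check your codec's expected input format):

```python
# Hypothetical pre-flight check for the inference inputs.
from pathlib import Path

import soundfile as sf  # assumed dependency for reading WAV metadata

required = [
    "/path/to/your/trained_model.pt",
    "/path/to/encoder.pt",
    "/path/to/decoder.pt",
    "/path/to/your_prompt.wav",
]
for path in required:
    assert Path(path).exists(), f"Missing required file: {path}"

info = sf.info("/path/to/your_prompt.wav")
print(f"prompt: {info.duration:.2f}s, {info.samplerate} Hz, {info.channels} channel(s)")
# The example data above uses 24kHz audio; resample your prompt if it differs.
```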
Note: If you don't want to retrain the decoder, you can use the same checkpoint from xcodec2 for both the encoder and decoder paths. We provide an xcodec2-compatible `model_config.json` file in `./example/codec`.
make help # Show all available commands
make install # Install development environment
make test # Run test suite
make test-coverage # Run tests with coverage report
make lint # Run code linting
make lint-fix # Auto-fix linting issues
make version # Show current version or bump version
make version patch # Bump patch version (1.0.0 → 1.0.1); also supports `minor` and `major`
git clone https://github.com/inworld-ai/tts.git
cd tts
# Set up development environment
make install
# Install pre-commit hooks
pre-commit install
# Make your changes and test
make lint-fix
make test-coverage
We welcome contributions! Please see our contributing guidelines:
- Fork the repository
- Create a feature branch:
git checkout -b feature/amazing-feature
- Run tests:
make test
- Run linting:
make lint-fix
- Commit changes:
git commit -m 'Add amazing feature'
- Push to branch:
git push origin feature/amazing-feature
- Open a Pull Request
- Meta AI team for open-sourcing LLaMA LLMs
- The PyTorch and Hugging Face communities
- Codec architecture inspired by Llasa
- Bug Reports: GitHub Issues
- General Questions: For general inquiries and support, please email us