
Commit a11722e

listen to @eonglints and add hubert with kmeans as an option

1 parent ed313d3

File tree: 4 files changed, +68 −9 lines

audiolm_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -5,3 +5,4 @@
 from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper

 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 8 additions & 7 deletions

@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import nn, einsum

@@ -12,6 +12,7 @@
 from vector_quantize_pytorch import ResidualVQ

 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

 # helper functions

@@ -551,7 +552,7 @@ def __init__(
         *,
         num_semantic_tokens,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()

@@ -574,8 +575,8 @@ def forward(

         if not exists(ids):
             assert exists(self.wav2vec)
-            ids = self.wav2vec(raw_wave)
-
+            ids = self.wav2vec(raw_wave, flatten = False)
+
         if return_loss:
             labels, ids = ids.clone(), ids[:, :-1]

@@ -606,7 +607,7 @@ def __init__(
         codebook_size,
         num_coarse_quantizers,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()

@@ -840,7 +841,7 @@ def __init__(
         *,
         transformer: FineTransformer,
         soundstream: Optional[SoundStream] = None,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         num_coarse_quantize = 3
     ):
         super().__init__()

@@ -866,7 +867,7 @@ def forward(

         if not exists(semantic_token_ids):
             assert exists(self.wav2vec), 'VQWav2Vec must be provided if given raw wave for training'
-            semantic_token_ids = self.wav2vec(raw_wave)
+            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)

         if not exists(coarse_token_ids):
             assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
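Taken together, these hunks widen every wav2vec type hint from FairseqVQWav2Vec to a Union that also admits the new HubertWithKmeans, and switch the raw-wave paths to call the encoder with flatten = False so the token ids keep their (batch, seq) shape. The swap is purely duck-typed: HubertWithKmeans (added below) exposes the same groups and codebook_size properties and the same forward(raw_wave, flatten = ...) signature. A minimal sketch of the new option follows; the checkpoint and k-means paths are hypothetical placeholders, not part of this commit:

import torch
from audiolm_pytorch import HubertWithKmeans

# hypothetical paths: a fairseq HuBERT checkpoint and a
# joblib-serialized k-means model fitted on HuBERT features
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_kmeans_500.bin'
)

# the attributes the transformer wrappers read off the encoder
num_semantic_tokens = wav2vec.codebook_size   # kmeans.n_clusters
groups = wav2vec.groups                       # always 1 for hubert + kmeans

# the call the wrappers now make when handed a raw wave
wave = torch.randn(1, 16000)                  # ~1 second of 16 kHz audio
ids = wav2vec(wave, flatten = False)          # semantic token ids, shape (batch, seq)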

audiolm_pytorch/hubert_kmeans.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+from pathlib import Path
+
+import torch
+from torch import nn
+from einops import rearrange, pack, unpack
+
+import joblib
+import fairseq
+
+class HubertWithKmeans(nn.Module):
+    def __init__(
+        self,
+        checkpoint_path,
+        kmeans_path
+    ):
+        super().__init__()
+        model_path = Path(checkpoint_path)
+        kmeans_path = Path(kmeans_path)
+
+        assert model_path.exists(), f'path {checkpoint_path} does not exist'
+        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'
+
+        checkpoint = torch.load(checkpoint_path)
+        load_model_input = {checkpoint_path: checkpoint}
+        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
+
+        self.model = model[0]
+        self.model.eval()
+
+        kmeans = joblib.load(kmeans_path)
+        self.kmeans = kmeans
+
+    @property
+    def groups(self):
+        return 1
+
+    @property
+    def codebook_size(self):
+        return self.kmeans.n_clusters
+
+    @torch.no_grad()
+    def forward(self, wav_input, flatten = True):
+        device = wav_input.device
+
+        embed = self.model(wav_input, features_only = True)
+        embed, packed_shape = pack([embed['x']], '* d')
+
+        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
+
+        codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
+
+        if flatten:
+            return codebook_indices
+
+        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
+        return codebook_indices
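For reference, the forward pass runs HuBERT with features_only = True, packs the resulting (batch, seq, dim) features into (batch * seq, dim) so the scikit-learn k-means can predict cluster ids in a single call, then, when flatten = False, unpacks the ids back to (batch, seq). A self-contained sketch of just that pack → predict → unpack pattern, using random stand-ins (hypothetical shapes, not from the commit) in place of the HuBERT features and the fitted k-means:

import numpy as np
import torch
from einops import pack, unpack
from sklearn.cluster import KMeans

embed = torch.randn(2, 49, 768)                                 # stand-in HuBERT features: (batch, seq, dim)
kmeans = KMeans(n_clusters = 50).fit(np.random.randn(200, 768)) # toy fitted model

flat, packed_shape = pack([embed], '* d')   # (2 * 49, 768), remembers (2, 49)
ids = kmeans.predict(flat.numpy())          # (98,) cluster indices as numpy
ids = torch.from_numpy(ids).long()          # what flatten = True returns

ids, = unpack(ids, packed_shape, '*')       # back to (2, 49): the flatten = False path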

setup.py

Lines changed: 3 additions & 2 deletions

@@ -3,7 +3,7 @@
 setup(
   name = 'audiolm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
   author = 'Phil Wang',

@@ -19,9 +19,10 @@
   ],
   install_requires=[
     'accelerate',
-    'einops>=0.5',
+    'einops>=0.6',
     'ema-pytorch',
     'fairseq',
+    'joblib',
     'torch>=1.6',
     'vector-quantize-pytorch>=0.10.5'
   ],
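The dependency changes track the new module: joblib is required to deserialize the fitted k-means model, and the einops floor rises to 0.6, the release that introduced the pack / unpack helpers used in hubert_kmeans.py.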
