39 changes: 33 additions & 6 deletions hidden_context/train_llm_vae_preference_model.py
@@ -15,7 +15,7 @@
    AutoModelForCausalLM,
)
from transformers.utils import PaddingStrategy
from .vae_utils import VAETrainer, VAEModel
from .vae_utils import VAETrainer, VAEModel, VQVAETrainer, VQVAE_Encoder

from .train_llm_preference_model import (
    get_step_decay_lr_lambda,
@@ -254,6 +254,7 @@ def __call__(self, examples):

trainer_classes: Dict[RewardModelType, Type[VAETrainer]] = {
    "vae": VAETrainer,
    "vqvae": VQVAETrainer
}


@@ -607,8 +608,12 @@ def up_sample_controversial(dataset, seed):
embed_dim = script_args.embed_dim

if not script_args.use_causal_lm:
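    # For "vqvae", the classification head is used as a scalar reward head in
    # VQVAETrainer.compute_loss, so a single output label suffices; the "vae" variant
    # keeps embed_dim outputs.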
    if script_args.reward_model_type == "vqvae":
        num_labels = 1
    else:
        num_labels = embed_dim
    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.model_name, num_labels=embed_dim, torch_dtype=torch.bfloat16
        script_args.model_name, num_labels=num_labels, torch_dtype=torch.bfloat16
    )
    # We multiply the final linear layer's weights by 0.01 because this seems to
    # significantly stabilize training and lead to better optimization of the loss.
@@ -653,10 +658,32 @@ def up_sample_controversial(dataset, seed):
# Train the model.
latent_dim = script_args.latent_dim
hidden_dim = script_args.hidden_dim
vae_model = VAEModel(embed_dim, hidden_dim, latent_dim, model,
                     fixed_contexts=script_args.fixed_contexts,
                     fixed_llm_embeddings=script_args.fixed_llm_embeddings,
                     use_causal_lm=script_args.use_causal_lm,)
if script_args.reward_model_type == "vae":
    vae_model = VAEModel(embed_dim, hidden_dim, latent_dim, model,
                         fixed_contexts=script_args.fixed_contexts,
                         fixed_llm_embeddings=script_args.fixed_llm_embeddings,
                         use_causal_lm=script_args.use_causal_lm,)
elif script_args.reward_model_type == "vqvae":
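    # The quantized codebook vectors get prepended to the LM's token embeddings in
    # VQVAETrainer.compute_loss, so embed_dim must match the LM hidden size
    # (768 for gpt2, 4096 for Llama-2-7b-hf); context_dim is the dimensionality of the
    # per-pair context embeddings consumed by the PairEncoder.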
    if script_args.model_name == 'gpt2':
        embed_dim = 768
    if script_args.model_name == 'meta-llama/Llama-2-7b-hf':
        embed_dim = 4096

    if script_args.use_causal_lm:
        context_dim = embed_dim
    else:
        context_dim = script_args.embed_dim

    vae_model = VQVAE_Encoder(
        script_args.n_embeddings,
        embed_dim,
        hidden_dim,
        model,
        context_dim=context_dim,
        fixed_contexts=script_args.fixed_contexts,
        fixed_llm_embeddings=script_args.fixed_llm_embeddings,
        use_causal_lm=script_args.use_causal_lm,
    )

trainer = trainer_class(
    model=vae_model,
268 changes: 267 additions & 1 deletion hidden_context/vae_utils.py
@@ -9,7 +9,7 @@
import torch.nn as nn
from transformers import Trainer, EvalPrediction
import wandb

import matplotlib.pyplot as plt

class PairEncoder(nn.Module):
"""
@@ -458,3 +458,269 @@ def cyclical_setter(self, value):
        else:
            self.cyclical = value
        return


class VQVAE_Encoder(nn.Module):
    def __init__(self, n_embeddings, embed_dim, hidden_dim, llm_encoder,
                 context_dim=None, commitment_cost=0.25, decay=0.999, epsilon=1e-5,
                 fixed_contexts=False, fixed_llm_embeddings=False, use_causal_lm=False):
        super(VQVAE_Encoder, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        self.llm_encoder = llm_encoder
        self.pair_encoder = PairEncoder(context_dim, hidden_dim, embed_dim)
        self.sequence_encoder = SequenceEncoder(embed_dim, embed_dim)

        # TODO: initialise using llms
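        # For now, each of the n_embeddings codes is drawn from a unit Gaussian centred
        # on the LLM's mean token embedding (this assumes a GPT-2-style
        # `transformer.wte` embedding table), so the codes start out in the same region
        # of space as the model's own token embeddings.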
        mean_wte = llm_encoder.transformer.wte.weight.mean(0)
        weights = torch.randn(size=(n_embeddings, embed_dim), dtype=mean_wte.dtype) + mean_wte
        self.embedding = nn.Parameter(weights, requires_grad=True)
        # self.register_buffer("embedding", embedding)
        # self.register_buffer("ema_count", torch.zeros(n_embeddings))
        # self.register_buffer("ema_weight", self.embedding.clone())

        self.fixed_contexts = fixed_contexts
        self.fixed_llm_embeddings = fixed_llm_embeddings
        self.use_causal_lm = use_causal_lm

    def encode_pair(self, e_c, e_r):
        return self.pair_encoder(e_c, e_r)

    def encode_sequence(self, sequences, seq_start_end):
        e_z, _ = self.sequence_encoder(sequences, seq_start_end)
        return e_z

    # def discretize(self, x):
    #     M, D = self.embedding.size()
    #     x_flat = x.detach().reshape(-1, D)

    #     distances = (-torch.cdist(x_flat, self.embedding, p=2)) ** 2

    #     indices = torch.argmin(distances.float(), dim=-1)
    #     quantized = F.embedding(indices, self.embedding)
    #     quantized = quantized.view_as(x)
    #     return quantized, indices

    def retrieve_random_codebook(self, random_indices):
        quantized = F.embedding(random_indices, self.embedding)
        quantized = quantized.transpose(1, 3)

        return quantized

    def gt_forward(
        self,
        user_type,
        seq_start_end,
    ):
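        # Debug path: index the codebook directly with the known (ground-truth) user type
        # instead of inferring it from the preference pairs; both VQ losses are zero here.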
        quantized = self.embedding[user_type.long()]
        commitment_loss = torch.Tensor([0.0])
        codebook_loss = torch.Tensor([0.0])
        # import pdb; pdb.set_trace()
        return quantized, commitment_loss, codebook_loss, user_type  # , perplexity

    def forward(
        self,
        context_chosen,
        context_rejected,
        seq_start_end,
        user_type,
        ground_truth_user_vector=True
    ):
        # import pdb; pdb.set_trace()
        if ground_truth_user_vector:
            return self.gt_forward(user_type, seq_start_end)
        pair_embed = self.encode_pair(context_chosen, context_rejected)
        x = self.encode_sequence(pair_embed, seq_start_end)
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        distances = (-torch.cdist(x_flat, self.embedding, p=2)) ** 2
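        # (-cdist)**2 == cdist**2, so the argmin below still picks the nearest codebook
        # entry under squared Euclidean distance.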

        indices = torch.argmin(distances.float(), dim=-1)
        encodings = F.one_hot(indices, M).float()
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)

        # TODO: fix EMA loss
        # if self.training:
        #     self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
        #     n = torch.sum(self.ema_count)
        #     self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

        #     dw = torch.matmul(encodings.t(), x_flat)
        #     self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
        #     self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        # TODO: check how gradients flow: do we need to pass gradients to the embeddings,
        # or is the codebook loss alone sufficient?
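        # Standard VQ-VAE terms: the codebook loss pulls the selected code towards the
        # encoder output, and the commitment loss is meant to pull the encoder output
        # towards its code (note x_flat is built from x.detach(), so as written no
        # gradient reaches the encoder through it). The straight-through estimator below
        # copies gradients from `quantized` back to `x` across the nearest-neighbour lookup.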
        codebook_loss = F.mse_loss(x_flat.detach(), quantized) * 0.1
        e_latent_loss = F.mse_loss(x_flat, quantized.detach())
        commitment_loss = self.commitment_cost * e_latent_loss * 0.1

        quantized = x + (quantized - x).detach()

        # avg_probs = torch.mean(encodings, dim=0)
        # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # import pdb; pdb.set_trace()
        return quantized, commitment_loss, codebook_loss, indices  # , perplexity

    def save_model(self, path):
        torch.save(self, path)

class VQVAETrainer(VAETrainer):
    def __init__(
        self, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.pad_token = torch.tensor([50256])

    def compute_loss(self, wrapped_model, inputs, return_outputs=False):
        model = wrapped_model  # .module
        device = model.llm_encoder.device
        batch_size = inputs["seq_start_end"].shape[0]
        self.pad_token = self.pad_token.to(device)

        embeddings_chosen = model.llm_encoder.transformer.wte(inputs["input_ids_chosen"])
        embeddings_rejected = model.llm_encoder.transformer.wte(inputs["input_ids_rejected"])
        attention_mask_padding = torch.ones_like(inputs["attention_mask_chosen"][:, None, 0])
        attention_mask_chosen = torch.cat((attention_mask_padding, inputs["attention_mask_chosen"]), dim=-1)
        attention_mask_rejected = torch.cat((attention_mask_padding, inputs["attention_mask_rejected"]), dim=-1)

        seq_len_chosen = (inputs["input_ids_chosen"] != self.pad_token).sum(dim=1)
        seq_len_rejected = (inputs["input_ids_rejected"] != self.pad_token).sum(dim=1)
        seq_len = torch.cat([seq_len_chosen, seq_len_rejected]) + 1
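        # 50256 is GPT-2's <|endoftext|> token, which doubles as the pad token here; the
        # +1 accounts for the user embedding prepended to every sequence below.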

        if model.fixed_contexts:
            contexts_embeddings_chosen = torch.tensor(inputs["contexts_embeddings_chosen"]).to(device).bfloat16()
            contexts_embeddings_rejected = torch.tensor(inputs["contexts_embeddings_rejected"]).to(device).bfloat16()
        else:
            if model.use_causal_lm:
                last_hidden_state_chosen = model.llm_encoder(
                    input_ids=inputs["contexts_input_ids_chosen"],
                    attention_mask=inputs["contexts_attention_mask_chosen"],
                    output_hidden_states=True
                ).hidden_states[-1]
                masked_last_hidden_state_chosen = last_hidden_state_chosen * inputs[
                    "contexts_attention_mask_chosen"].unsqueeze(-1)
                token_length_chosen = torch.sum(inputs["contexts_attention_mask_chosen"], dim=1)
                contexts_embeddings_chosen = torch.sum(masked_last_hidden_state_chosen,
                                                       dim=1) / token_length_chosen.unsqueeze(-1)

                last_hidden_state_rejected = model.llm_encoder(
                    input_ids=inputs["contexts_input_ids_rejected"],
                    attention_mask=inputs["contexts_attention_mask_rejected"],
                    output_hidden_states=True
                ).hidden_states[-1]
                masked_last_hidden_state_rejected = last_hidden_state_rejected * inputs[
                    "contexts_attention_mask_rejected"].unsqueeze(-1)
                token_length_rejected = torch.sum(inputs["contexts_attention_mask_rejected"], dim=1)
                contexts_embeddings_rejected = torch.sum(masked_last_hidden_state_rejected,
                                                         dim=1) / token_length_rejected.unsqueeze(-1)
            else:
                contexts_embeddings_chosen = model.llm_encoder(
                    inputs["contexts_input_ids_chosen"],
                    inputs["contexts_attention_mask_chosen"]
                )[0]
                contexts_embeddings_rejected = model.llm_encoder(
                    inputs["contexts_input_ids_rejected"],
                    inputs["contexts_attention_mask_rejected"]
                )[0]
        seq_start_end = inputs["seq_start_end"]
        user_type = torch.tensor(inputs["user_type"]).to(device).bfloat16()

        quantized, commitment_loss, codebook_loss, indices = model(
            contexts_embeddings_chosen,
            contexts_embeddings_rejected,
            seq_start_end,
            user_type,
            ground_truth_user_vector=False  # TODO: set to True for debug usage
        )
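        # `quantized` holds the selected codebook vector (the inferred user embedding);
        # it is cast to the LM dtype and prepended as an extra "token" in front of both
        # the chosen and rejected responses, so the reward model can condition on it.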
        quantized = quantized.to(device).bfloat16()

        embeddings_chosen = torch.cat((quantized[:, None], embeddings_chosen), dim=1)
        embeddings_rejected = torch.cat((quantized[:, None], embeddings_rejected), dim=1)

        output_dict = model.llm_encoder(
            inputs_embeds=torch.concatenate(
                [
                    embeddings_chosen,
                    embeddings_rejected,
                ],
                dim=0,
            ),
            attention_mask=torch.concatenate(
                [
                    attention_mask_chosen,
                    attention_mask_rejected,
                ],
                dim=0,
            ),
            return_dict=True,
            output_hidden_states=True
        )
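        # Read out the final hidden state at position seq_len (the first EOS/pad slot of
        # each prepended sequence) and map it to a scalar reward with the classifier
        # head; chosen and rejected responses were concatenated along the batch dimension.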

        batch_indices = torch.arange(len(seq_len)).to(device)
        hidden_states = output_dict["hidden_states"][-1][batch_indices, seq_len]
        rewards = model.llm_encoder.score(hidden_states)

        # rewards = rewards[0]
        rewards_chosen = rewards[:batch_size]
        rewards_rejected = rewards[batch_size:]

        reproduction_loss = self.loss(rewards_chosen, rewards_rejected)
        loss = reproduction_loss  # + commitment_loss + codebook_loss

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
                "commitment_loss": commitment_loss,
                "codebook_loss": codebook_loss,
                "z": quantized,
                "user_type": user_type,
                "indices": indices,
                "embeddings": model.embedding,
            }
        else:
            accuracy = torch.mean((rewards_chosen > rewards_rejected).float())
            self.log(
                {
                    "rewards_chosen": rewards_chosen.mean().item(),
                    "rewards_rejected": rewards_rejected.mean().item(),
                    "train_commitment_loss": commitment_loss.item(),
                    "train_codebook_loss": codebook_loss.item(),
                    "train_loss": loss.item(),
                    "train_reproduction_loss": reproduction_loss.item(),
                    "train_accuracy": accuracy.item(),
                }
            )
            return loss

    @classmethod
    def compute_metrics(cls, eval_prediction: EvalPrediction):
        rewards_chosen, rewards_rejected, commitment_loss, codebook_loss, z, user_type, indices, embeddings = (
            eval_prediction.predictions
        )
        rewards_chosen = torch.from_numpy(rewards_chosen)
        rewards_rejected = torch.from_numpy(rewards_rejected)

        loss = cls.per_sample_loss(rewards_chosen, rewards_rejected)
        accuracy = torch.mean((rewards_chosen > rewards_rejected).float())  # torch.mean((loss < np.log(2)).float())

        # import pdb; pdb.set_trace()
        embeddings_table = wandb.Table(columns=list(range(z.shape[1])), data=embeddings)

        unique_users = np.unique(user_type)
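        # One histogram of selected codebook indices per ground-truth user, to check
        # whether different users are assigned distinct codes.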
        fig, axs = plt.subplots(1, len(unique_users), figsize=(20, 5))
        for i, uid in enumerate(unique_users):
            user_indices = indices[np.argwhere(user_type == uid)]
            axs[i].hist(user_indices)
            axs[i].set_title(f"User {i}")
        im = wandb.Image(fig)

        return {
            "reproduction_loss": loss.mean().item(),
            "accuracy": accuracy.item(),
            "commitment_loss": commitment_loss.mean().item(),
            "codebook_loss": codebook_loss.mean().item(),
            "embeddings_table": embeddings_table,
            "latents": im,
        }