Merged
49 changes: 41 additions & 8 deletions examples/language/palm/train.py
@@ -1,9 +1,11 @@
import gzip
import random

from time import time
from functools import partial
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import tqdm
from packaging import version
from palm_pytorch import PaLM
Expand All @@ -21,7 +23,8 @@

# constants

NUM_BATCHES = int(1000)
NUM_BATCHES = int(100)
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
@@ -76,10 +79,18 @@ def cycle(loader):
def decode_token(token):
return str(chr(max(32, token)))

def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))

def get_model_size(model: nn.Module):
total_numel = 0
for module in model.modules():
for p in module.parameters(recurse=False):
total_numel += p.numel()
return total_numel
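
A quick aside on get_model_size (not part of the PR): for a model with no parameters shared across modules, it returns the same count as summing p.numel() over model.parameters(), which de-duplicates shared tensors. A minimal sanity check with a hypothetical toy model:

import torch.nn as nn

toy = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
# Both counts agree here because no parameter is registered on two modules.
assert get_model_size(toy) == sum(p.numel() for p in toy.parameters())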

# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())

param.visited = True


@@ -152,6 +162,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()

with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
@@ -188,7 +199,7 @@ def __len__(self):
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)

with ctx:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = PaLM(num_tokens=50304, dim=4096, depth=64)
Contributor:
If there is autoregressive logic here, the TFLOPS formula may be incorrect.
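
For context on this comment: the factor of 8 in get_tflops is the common parameter-only estimate (2 FLOPs per parameter per token for the forward pass, 4 for the backward pass, plus 2 for the recomputed forward pass when activation checkpointing is enabled). It ignores the attention term, which for a causal model grows with the square of the sequence length, so the reported TFLOPS is an approximation. A hedged sketch of that estimate, with the flag name chosen here purely for illustration:

def approx_step_tflops(model_numel, batch_size, seq_len, step_time,
                       checkpoint_activations=True):
    # 2 FLOPs/param/token forward + 4 backward, plus 2 more if the
    # forward pass is recomputed for activation checkpointing.
    flops_per_param_token = 8 if checkpoint_activations else 6
    total_flops = flops_per_param_token * model_numel * batch_size * seq_len
    return total_flops / 1e12 / (step_time + 1e-12)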

model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

pg = default_pg
@@ -205,25 +216,42 @@ def __len__(self):
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


# model is shared after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
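
As a small illustration (the numbers below are made up), functools.partial only pre-binds the constant arguments, so inside the loop get_tflops_func(step_time) is shorthand for get_tflops(numel, args.batch_size, SEQ_LEN, step_time):

# e.g. with a hypothetical 8B-parameter model, batch size 8, sequence length 1024:
example_func = partial(get_tflops, 8_000_000_000, 8, 1024)
example_func(0.5)  # == get_tflops(8_000_000_000, 8, 1024, 0.5)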

# training
model.train()

tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):

if args.distplan == "colossalai":
optimizer.zero_grad()

start = time()
loss = model(next(train_loader))
fwd_end = time()
fwd_time = fwd_end - start
# loss.backward()
optimizer.backward(loss)
bwd_end = time()
bwd_time = bwd_end - fwd_end

print(f"training loss: {loss.item()}")
# print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
optim_time = time() - bwd_end
step_time = time() - start

step_tflops = get_tflops_func(step_time)
logger.info(
f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
ranks=[0],
)
if i >= WARMUP_BATCHES:
tflops_list.append(step_tflops)

else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
@@ -233,6 +261,11 @@ def __len__(self):
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()

tflops_list.sort()
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")


# TODO
# if i % VALIDATE_EVERY == 0:
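
Overall, the benchmarking pattern this diff adds (skip WARMUP_BATCHES iterations, record per-step TFLOPS, report the median) can be sketched in isolation as below. This is not the PR's exact code: it reuses the get_tflops_func bound earlier and uses statistics.median instead of manual index arithmetic. Note that tflops_list holds only NUM_BATCHES - WARMUP_BATCHES entries, so the median_index computed above (which adds WARMUP_BATCHES back) actually selects an element slightly above the true median of the recorded list.

from statistics import median
from time import time

def benchmark(step_fn, num_batches=100, warmup_batches=1):
    # Time each training step, discard the warm-up measurements
    # (allocator warm-up and lazy initialization distort early steps),
    # and return the median throughput in TFLOPS.
    tflops_list = []
    for i in range(num_batches):
        start = time()
        step_fn()
        step_time = time() - start
        if i >= warmup_batches:
            tflops_list.append(get_tflops_func(step_time))
    return median(tflops_list)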