33 changes: 32 additions & 1 deletion sagemaker-triton/inferentia2/triton_inferentia2.ipynb
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "60a1fdce",
"metadata": {},
@@ -9,6 +10,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a6ce5cb1",
"metadata": {},
@@ -22,6 +24,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b16f14ea",
"metadata": {},
@@ -41,6 +44,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "cf042bea",
"metadata": {},
@@ -158,6 +162,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7bb2cab3-c977-4d2e-b181-611b2773e30b",
"metadata": {},
@@ -166,6 +171,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4f618f8e",
"metadata": {},
@@ -187,7 +193,7 @@
"\n",
"s3_client = boto3.client(\"s3\")\n",
"s3_client.download_file(\n",
" \"sagemaker-sample-files\", \"datasets/image/pets/shiba_inu_dog.jpg\", \"shiba_inu_dog.jpg\"\n",
" \"sagemaker-example-files-prod-us-east-2\", \"datasets/image/pets/shiba_inu_dog.jpg\", \"shiba_inu_dog.jpg\"\n",
")\n",
"\n",
"\n",
@@ -204,6 +210,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c171f622",
"metadata": {},
@@ -244,6 +251,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "00f0f261-e960-4a00-a9ad-8a884f9f27aa",
"metadata": {},
@@ -274,6 +282,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ea0bd27-0e80-44b6-bb1e-322c34dbb9cb",
"metadata": {},
@@ -342,6 +351,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "554b50cb-4e32-4ad2-8d59-0391a2294c98",
"metadata": {},
@@ -373,6 +383,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d3c6ab5c-5991-4959-8b85-439ab44498ab",
"metadata": {},
@@ -459,6 +470,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "05b7fd73-2107-4705-922a-80dd7ef16833",
"metadata": {},
@@ -492,6 +504,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "15ac5e59-936e-4adb-a91e-67db42735307",
"metadata": {},
@@ -522,6 +535,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4dc97ed0-3155-4658-96f3-7e058c801e7c",
"metadata": {},
@@ -592,6 +606,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "aa5cca31-bcd6-4ded-9b1a-085ee8e2094b",
"metadata": {},
@@ -627,6 +642,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "de473c37-7e5a-4f72-bac1-06524622e41f",
"metadata": {},
@@ -666,6 +682,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fbbd1c02-fcbf-4f7c-b05b-aba9775449de",
"metadata": {},
@@ -686,6 +703,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ca3088a0-09c8-47dc-b21e-270a6f82df51",
"metadata": {},
@@ -695,6 +713,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "79ed11d0-9e14-46d4-955c-dd02c04e7867",
"metadata": {},
@@ -703,6 +722,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f079437f-f90d-4ff7-b90b-efbbc9625861",
"metadata": {},
@@ -781,6 +801,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "17500c6b-d59c-44bf-93af-7e1fb7fd6783",
"metadata": {},
@@ -801,6 +822,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6fe19a0a-1c82-422f-ac7b-8e3549e79145",
"metadata": {},
@@ -871,6 +893,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "61f8bb6f-8605-4e81-8ed2-87e9fbbc4f52",
"metadata": {},
@@ -945,6 +968,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8ea72aed-98b2-4a2d-9eb6-8f65c04f671e",
"metadata": {},
@@ -999,6 +1023,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b3561d5a-9ab0-4205-85ee-4aefacc8f849",
"metadata": {},
@@ -1008,6 +1033,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f42bec12",
"metadata": {},
@@ -1016,6 +1042,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ecd78917-ab23-46db-941b-8443c767448c",
"metadata": {},
@@ -1024,6 +1051,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "907d3eb5-acf2-4f10-843a-715e82ea51d6",
"metadata": {},
@@ -1099,6 +1127,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "076950ad-ab9d-44d2-9826-a463848af213",
"metadata": {},
@@ -1159,6 +1188,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "81a98829-7497-4e77-944d-0621719f4a71",
"metadata": {},
@@ -1249,6 +1279,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ddf79acd-3ad0-4e88-b746-1a831cc257c7",
"metadata": {},
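The notebook hunks add an empty "attachments" field to each markdown cell; the one substantive change is the switch of the sample-data bucket from sagemaker-sample-files to sagemaker-example-files-prod-us-east-2. As a minimal standalone sketch of the updated download, using exactly the bucket, object key, and local filename shown in the hunk above (AWS credentials are assumed, as in the original notebook):

```python
import os

import boto3

# Fetch the sample image from the updated example-files bucket
# (bucket and key copied from the diff above).
s3_client = boto3.client("s3")
s3_client.download_file(
    "sagemaker-example-files-prod-us-east-2",
    "datasets/image/pets/shiba_inu_dog.jpg",
    "shiba_inu_dog.jpg",
)

print("downloaded:", os.path.getsize("shiba_inu_dog.jpg"), "bytes")
```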
@@ -1,16 +1,14 @@
import gzip
import json
import os
import h5py
from typing import List, Tuple
import random

import h5py
import numpy as np
import smdistributed.modelparallel.torch as smp
import torch


class WikiPretrainingDataset(torch.utils.data.Dataset):
class BertPretrainingDataset(torch.utils.data.Dataset):
def __init__(self, input_file, max_pred_length):
self.input_file = input_file
self.max_pred_length = max_pred_length
@@ -56,9 +54,15 @@ def __getitem__(self, index):
return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels]


###### Load Openwebtext pretraining data ######
class OpenwebtextPretrainingDataset(torch.utils.data.Dataset):
def __init__(self, input_paths: List[str], max_sequence_length=None, zipped=True, use_last_file_only=False):
###### Load GPT pretraining data ######
class GPTPretrainingDataset(torch.utils.data.Dataset):
def __init__(
self,
input_paths: List[str],
max_sequence_length=None,
zipped=True,
use_last_file_only=False,
):
self.input_paths = input_paths
self.max_sequence_length = max_sequence_length
self.zipped = zipped
@@ -79,11 +83,11 @@ def __read_examples(self, paths: List[str]):
self.input_data.extend([ln for _, ln in enumerate(f, 1)])
else:
if self.use_last_file_only:
with open (paths[-1], "r") as f:
with open(paths[-1], "r") as f:
self.input_data = [ln for ln in f]
else:
for path in paths:
with open (path, "r") as f:
with open(path, "r") as f:
self.input_data.extend([ln for ln in f])

# print(f'__Finished building pretraining dataset with {self.iids.shape[0]} rows__')
@@ -102,15 +106,27 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
e_idx = s_idx + self.max_sequence_length
iids = iids[s_idx:e_idx]
attns = attns[s_idx:e_idx]

# Hack to use 4096 seqlen with our existing synthetic data for benchmarking purposes only
# iids = iids.repeat(1,2).flatten()
# attns = attns.repeat(1,2).flatten()
# assert iids.shape[0] == 4096, iids.shape

return iids, attns


class DummyDataset(torch.utils.data.dataset.Dataset):
def __init__(self, length, data_type="openwebtext"):
if data_type == "openwebtext":
def __init__(self, length, data_type="GPT"):
if data_type == "GPT":
self.batch = (torch.Tensor(0), torch.Tensor(0))
elif data_type == "wiki":
self.batch = (torch.Tensor(0), torch.Tensor(0), torch.Tensor(0), torch.Tensor(0), torch.Tensor(0))
elif data_type == "BERT":
self.batch = (
torch.Tensor(0),
torch.Tensor(0),
torch.Tensor(0),
torch.Tensor(0),
torch.Tensor(0),
)
self.length = length

def __getitem__(self, index):
@@ -130,26 +146,30 @@ def create_pretraining_dataloader(
shuffle: bool = False,
zipped: bool = True,
use_last_file_only: bool = False,
data_type: str = "openwebtext",
data_type: str = "GPT",
):
if smp.pp_rank() == 0:
if data_type == "openwebtext":
data = OpenwebtextPretrainingDataset(
input_paths=input_paths, max_sequence_length=max_sequence_length, zipped=zipped, use_last_file_only=use_last_file_only
if data_type == "GPT":
data = GPTPretrainingDataset(
input_paths=input_paths,
max_sequence_length=max_sequence_length,
zipped=zipped,
use_last_file_only=use_last_file_only,
)
elif data_type == "wiki":
elif data_type == "BERT":
if len(input_paths) > 1:
print(f"Wiki data only support single file when calling create_pretraining_dataloader, reading the first file instead..")
data = WikiPretrainingDataset(input_file=input_paths[0], max_pred_length=max_sequence_length)
print(
f"BERT data only support single file when calling create_pretraining_dataloader, reading the first file instead.."
)
data = BertPretrainingDataset(
input_file=input_paths[0], max_pred_length=max_sequence_length
)
else:
raise ValueError(f"Unsupported data type {data_type}")
# TODO: set sampler.epoch to correctly shuffle across epochs, else same order will be used for all epochs
# not relevant now as we have no epochs
sampler = torch.utils.data.DistributedSampler(
data,
shuffle=shuffle,
seed=seed,
rank=dp_rank,
num_replicas=dp_size,
drop_last=True,
data, shuffle=shuffle, seed=seed, rank=dp_rank, num_replicas=dp_size, drop_last=True
)
dataloader = torch.utils.data.DataLoader(
data,
@@ -165,4 +185,4 @@ def create_pretraining_dataloader(
dataset = DummyDataset(data_len * batch_size, data_type=data_type)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True)

return dataloader
return dataloader
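The second changed file (its path is not captured above) is the pretraining data-pipeline module: the PR renames WikiPretrainingDataset and OpenwebtextPretrainingDataset to BertPretrainingDataset and GPTPretrainingDataset, and the accepted data_type values from "openwebtext"/"wiki" to "GPT"/"BERT". Below is a usage sketch of the renamed entry point, not a definitive call: the module name, shard path, and batch values are illustrative; the keyword names that are not visible in the truncated signature (batch_size, seed, dp_rank, dp_size) are inferred from the function body, and extra required arguments may exist in the real file; an initialized smdistributed model-parallel runtime is assumed, since the loader branches on smp.pp_rank().

```python
import smdistributed.modelparallel.torch as smp

# Module name is an assumption; the changed file's path is not shown in this diff.
from data_pipeline import create_pretraining_dataloader

smp.init()  # the loader calls smp.pp_rank(), so SMP must already be initialized

train_loader = create_pretraining_dataloader(
    input_paths=["/opt/ml/input/data/train/shard_000.json.gz"],  # hypothetical shard
    batch_size=8,
    max_sequence_length=2048,
    seed=12345,
    dp_rank=smp.dp_rank(),   # data-parallel rank of this worker
    dp_size=smp.dp_size(),   # number of data-parallel workers
    shuffle=True,
    zipped=True,
    data_type="GPT",         # "BERT" selects BertPretrainingDataset instead
)

# GPTPretrainingDataset.__getitem__ returns (iids, attns), so each batch is a pair.
input_ids, attention_mask = next(iter(train_loader))
```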