Skip to content
2 changes: 0 additions & 2 deletions Vision/classification/image/resnet50/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ bash examples/train_graph_distributed_fp16.sh
Train resnet50 with graph mode.
--use-fp16
Whether to enable amp training.
--use-gpu-decode
Use gpu to decode the data packed in ofrecord, only supported in graph mode.
--scale-grad
Whether to scale gradient when training in fp32 with graph mode.
--skip-eval
Expand Down
30 changes: 24 additions & 6 deletions Vision/classification/image/resnet50/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def parse_args(ignore_unknown_args=False):
parser = argparse.ArgumentParser(
description="OneFlow ResNet50 Arguments", allow_abbrev=False
)
parser.add_argument("--device", type=str, default="cuda", help="device: cpu, cuda...")
parser.add_argument(
"--data-loading-device",
type=str,
default="cuda",
choices=["cpu", "cuda"],
help="Specify the device for data loading: 'cpu' or 'cuda' (default: 'cuda')."
)
parser.add_argument(
"--save",
type=str,
Expand Down Expand Up @@ -60,12 +68,6 @@ def parse_args(ignore_unknown_args=False):
dest="ofrecord_part_num",
help="ofrecord data part number",
)
parser.add_argument(
"--use-gpu-decode",
action="store_true",
dest="use_gpu_decode",
help="Use gpu decode.",
)
parser.add_argument(
"--synthetic-data",
action="store_true",
Expand All @@ -86,6 +88,22 @@ def parse_args(ignore_unknown_args=False):
dest="fuse_bn_add_relu",
help="Whether to use use fuse batch_normalization, add and relu.",
)
parser.add_argument(
"--disable-fuse-add-to-output",
action="store_false",
dest="fuse_add_to_output",
help="Disable fusion of the add operation into the output (enabled by default). \n"
"For more details, see `graph_config.py` in the OneFlow repository: \n"
"https://github.com/Oneflow-Inc/oneflow",
)
parser.add_argument(
"--disable-fuse-model-update-ops",
action="store_false",
dest="fuse_model_update_ops",
help="Disable fusion of the model update operations (enabled by default). \n"
"For more details, see `graph_config.py` in the OneFlow repository: \n"
"https://github.com/Oneflow-Inc/oneflow",
)

# training hyper-parameters
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@ python3 $SRC_DIR/train.py \
--save $CHECKPOINT_SAVE_PATH \
--samples-per-epoch 50 \
--val-samples-per-epoch 50 \
--use-gpu-decode \
--scale-grad \
--graph \
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,4 @@ python3 -m oneflow.distributed.launch \
--metric-train-acc True \
--fuse-bn-relu \
--fuse-bn-add-relu \
--use-gpu-decode \
--channel-last \
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ python3 -m oneflow.distributed.launch \
--num-epochs $EPOCH \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--use-gpu-decode \
--scale-grad \
--graph \
--fuse-bn-relu \
Expand Down
16 changes: 9 additions & 7 deletions Vision/classification/image/resnet50/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
elif args.scale_grad:
self.set_grad_scaler(make_static_grad_scaler())

self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_add_to_output(args.fuse_add_to_output)
self.config.allow_fuse_model_update_ops(args.fuse_model_update_ops)

# Disable cudnn_conv_heuristic_search_algo will open dry-run.
# Dry-run is better with single device, but has no effect with multiple device.
Expand All @@ -51,11 +51,12 @@ def __init__(
self.cross_entropy = cross_entropy
self.data_loader = data_loader
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.return_pred_and_label:
Expand All @@ -75,15 +76,16 @@ def __init__(self, model, data_loader):
if args.use_fp16:
self.config.enable_amp(True)

self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_add_to_output(args.fuse_add_to_output)

self.data_loader = data_loader
self.model = model
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
pred = logits.softmax()
return pred, label
10 changes: 8 additions & 2 deletions Vision/classification/image/resnet50/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def _parse_args():
dest="image_path",
help="input image path",
)
parser.add_argument(
"--device", type=str, default="cuda", choices=["cuda", "cpu", "npu"], help="device"
)
parser.add_argument("--graph", action="store_true", help="Run model in graph mode.")
return parser.parse_args()

Expand All @@ -52,10 +55,13 @@ def build(self, image):
def main(args):
start_t = time.perf_counter()

if args.device == "npu":
import oneflow_npu

print("***** Model Init *****")
model = resnet50()
model.load_state_dict(flow.load(args.model_path))
model = model.to("cuda")
model = model.to(args.device)
model.eval()
end_t = time.perf_counter()
print(f"***** Model Init Finish, time escapled {end_t - start_t:.6f} s *****")
Expand All @@ -65,7 +71,7 @@ def main(args):

start_t = end_t
image = load_image(args.image_path)
image = flow.Tensor(image, device=flow.device("cuda"))
image = flow.Tensor(image, device=flow.device(args.device))
if args.graph:
pred = model_graph(image)
else:
Expand Down
22 changes: 12 additions & 10 deletions Vision/classification/image/resnet50/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
channel_last=args.channel_last,
device=args.device,
)
return data_loader.to("cuda")
return data_loader.to(args.device)

ofrecord_data_loader = OFRecordDataLoader(
ofrecord_dir=args.ofrecord_path,
Expand All @@ -44,7 +45,7 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
channel_last=args.channel_last,
placement=placement,
sbp=sbp,
use_gpu_decode=args.use_gpu_decode,
device=args.data_loading_device,
)
return ofrecord_data_loader

Expand All @@ -61,7 +62,7 @@ def __init__(
channel_last=False,
placement=None,
sbp=None,
use_gpu_decode=False,
device="cuda",
):
super().__init__()

Expand All @@ -71,6 +72,7 @@ def __init__(
self.total_batch_size = total_batch_size
self.dataset_size = dataset_size
self.mode = mode
self.device = device

random_shuffle = True if mode == "train" else False
shuffle_after_epoch = True if mode == "train" else False
Expand Down Expand Up @@ -101,9 +103,8 @@ def __init__(
rgb_mean = [123.68, 116.779, 103.939]
rgb_std = [58.393, 57.12, 57.375]

self.use_gpu_decode = use_gpu_decode
if self.mode == "train":
if self.use_gpu_decode:
if self.device == "cuda":
self.bytesdecoder_img = flow.nn.OFRecordBytesDecoder("encoded")
self.image_decoder = flow.nn.OFRecordImageGpuDecoderRandomCropResize(
target_width=image_width,
Expand Down Expand Up @@ -153,17 +154,17 @@ def __len__(self):
def forward(self):
if self.mode == "train":
record = self.ofrecord_reader()
if self.use_gpu_decode:
if self.device == "cuda":
encoded = self.bytesdecoder_img(record)
image = self.image_decoder(encoded)
else:
image_raw_bytes = self.image_decoder(record)
image = self.resize(image_raw_bytes)[0]
image = image.to("cuda")

label = self.label_decoder(record)
flip_code = self.flip()
flip_code = flip_code.to("cuda")
if self.device == "cuda":
flip_code = flip_code.to(self.device)
image = self.crop_mirror_norm(image, flip_code)
else:
record = self.ofrecord_reader()
Expand All @@ -184,6 +185,7 @@ def __init__(
placement=None,
sbp=None,
channel_last=False,
device="cuda",
):
super().__init__()

Expand Down Expand Up @@ -220,10 +222,10 @@ def __init__(
)
else:
self.image = flow.randint(
0, high=256, size=self.image_shape, dtype=flow.float32, device="cuda"
0, high=256, size=self.image_shape, dtype=flow.float32, device=device,
)
self.label = flow.randint(
0, high=self.num_classes, size=self.label_shape, device="cuda",
0, high=self.num_classes, size=self.label_shape, device=device,
).to(dtype=flow.int32)

def forward(self):
Expand Down
3 changes: 2 additions & 1 deletion Vision/classification/image/resnet50/models/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,6 @@ def forward(self, input, label):
# log_prob = input.softmax(dim=-1).log()
# onehot_label = flow.F.cast(onehot_label, log_prob.dtype)
# loss = flow.mul(log_prob * -1, onehot_label).sum(dim=-1).mean()
loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
#loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
loss = flow._C.cross_entropy(input, onehot_label.to(dtype=input.dtype), reduction='none')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss这块确定要改吗😂(cross_entropy 内部貌似包含 2 个 op:log_softmax 和 nll,可能效率不一定有 softmax_cross_entropy 好)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

临时改的,为了能跑通,等志鹏那个开发好了,就改回来。

return loss.mean()
20 changes: 13 additions & 7 deletions Vision/classification/image/resnet50/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
class Trainer(object):
def __init__(self):
args = get_args()
self.device = args.device.lower()
if self.device == "npu":
import oneflow_npu
elif self.device == "xpu":
import oneflow_xpu

for k, v in args.__dict__.items():
setattr(self, k, v)

Expand Down Expand Up @@ -89,12 +95,12 @@ def init_model(self):
start_t = time.perf_counter()

if self.is_global:
placement = flow.env.all_device_placement("cuda")
placement = flow.env.all_device_placement(self.device)
self.model = self.model.to_global(
placement=placement, sbp=flow.sbp.broadcast
)
else:
self.model = self.model.to("cuda")
self.model = self.model.to(self.device)

if self.load_path is None:
self.legacy_init_parameters()
Expand Down Expand Up @@ -276,7 +282,7 @@ def train_eager(self):
param.grad /= self.world_size
else:
loss.backward()
loss = loss / self.world_size
#loss = loss / self.world_size

self.optimizer.step()
self.optimizer.zero_grad()
Expand Down Expand Up @@ -311,8 +317,8 @@ def eval(self):

def forward(self):
image, label = self.train_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.metric_train_acc:
Expand All @@ -323,8 +329,8 @@ def forward(self):

def inference(self):
image, label = self.val_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
with flow.no_grad():
logits = self.model(image)
pred = logits.softmax()
Expand Down