diff --git a/orcaexample/kinetics.py b/orcaexample/kinetics.py index 449867e6d..f747b451c 100644 --- a/orcaexample/kinetics.py +++ b/orcaexample/kinetics.py @@ -95,22 +95,22 @@ def loss_creator(config): ) val_stats = orca_estimator.evaluate(data=validation_data_creator(cfg,cfg.TEST.BATCH_SIZE)) print("===> Validation Complete: Top1Accuracy {}".format(val_stats["Accuracy"])) -# elif args.backend in ["ray", "spark"]: -# orca_estimator = Estimator.from_torch(model=model_creator, -# optimizer=optim_creator, -# loss=loss_creator, -# metrics=[Accuracy()], -# backend=args.backend, -# config=cfg, -# model_dir=os.getcwd(), -# use_tqdm=True) -# orca_estimator.fit(data=train_loader_creator, -# validation_data=validation_data_creator, -# batch_size=cfg.TRAIN.BATCH_SIZE, -# epochs=cfg.SOLVER.MAX_EPOCH) -# val_stats = orca_estimator.evaluate(data=validation_data_creator, batch_size=cfg.TEST.BATCH_SIZE) -# print("===> Validation Complete: Top1Accuracy {}".format(val_stats["Accuracy"])) -# orca_estimator.shutdown() +elif args.backend in ["ray", "spark"]: + orca_estimator = Estimator.from_torch(model=model_creator, + optimizer=optim_creator, + loss=loss_creator, + metrics=[Accuracy()], + backend=args.backend, + config=cfg, + model_dir=os.getcwd(), + use_tqdm=True) + orca_estimator.fit(data=train_loader_creator, + validation_data=validation_data_creator, + batch_size=cfg.TRAIN.BATCH_SIZE, + epochs=cfg.SOLVER.MAX_EPOCH) + val_stats = orca_estimator.evaluate(data=validation_data_creator, batch_size=cfg.TEST.BATCH_SIZE) + print("===> Validation Complete: Top1Accuracy {}".format(val_stats["Accuracy"])) + orca_estimator.shutdown() else: invalidInputError(False, "Only bigdl, ray, and spark are supported " "as the backend, but got {}".format(args.backend))