Skip to content

Commit 36ba079

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Fixes to CO3Dv2 provider.
Summary: 1. Random sampling of num batches without replacement not supported. 2.Providers should implement the interface for the training loop to work. Reviewed By: bottler, davnov134 Differential Revision: D37815388 fbshipit-source-id: 8a2795b524e733f07346ffdb20a9c0eb1a2b8190
1 parent b95ec19 commit 36ba079

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,47 @@ data_source_args:
335335
sort_frames: false
336336
path_manager_factory_PathManagerFactory_args:
337337
silence_logs: true
338+
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
339+
category: ???
340+
subset_name: ???
341+
dataset_root: ''
342+
test_on_train: false
343+
only_test_set: false
344+
load_eval_batches: true
345+
dataset_class_type: JsonIndexDataset
346+
path_manager_factory_class_type: PathManagerFactory
347+
dataset_JsonIndexDataset_args:
348+
path_manager: null
349+
frame_annotations_file: ''
350+
sequence_annotations_file: ''
351+
subset_lists_file: ''
352+
subsets: null
353+
limit_to: 0
354+
limit_sequences_to: 0
355+
pick_sequence: []
356+
exclude_sequence: []
357+
limit_category_to: []
358+
dataset_root: ''
359+
load_images: true
360+
load_depths: true
361+
load_depth_masks: true
362+
load_masks: true
363+
load_point_clouds: false
364+
max_points: 0
365+
mask_images: false
366+
mask_depths: false
367+
image_height: 800
368+
image_width: 800
369+
box_crop: true
370+
box_crop_mask_thr: 0.4
371+
box_crop_context: 0.3
372+
remove_empty_masks: true
373+
n_frames_per_sequence: -1
374+
seed: 0
375+
sort_frames: false
376+
eval_batches: null
377+
path_manager_factory_PathManagerFactory_args:
378+
silence_logs: true
338379
dataset_map_provider_LlffDatasetMapProvider_args:
339380
base_dir: ???
340381
object_name: ???

pytorch3d/implicitron/dataset/data_loader_map_provider.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,13 @@ def _simple_loader(
354354
"""
355355
if num_batches > 0:
356356
num_samples = self.batch_size * num_batches
357+
replacement = True
357358
else:
358359
num_samples = None
359-
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
360+
replacement = False
361+
sampler = RandomSampler(
362+
dataset, replacement=replacement, num_samples=num_samples
363+
)
360364
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
361365
return DataLoader(
362366
dataset,

pytorch3d/implicitron/dataset/data_source.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
1818
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
1919
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
20+
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
2021
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
2122

2223

pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
import logging
1010
import os
1111
import warnings
12-
from typing import Dict, List, Type
12+
from typing import Dict, List, Optional, Type
1313

1414
from pytorch3d.implicitron.dataset.dataset_map_provider import (
1515
DatasetMap,
1616
DatasetMapProviderBase,
1717
PathManagerFactory,
18+
Task,
1819
)
1920
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
2021
from pytorch3d.implicitron.tools.config import (
@@ -23,6 +24,8 @@
2324
run_auto_creation,
2425
)
2526

27+
from pytorch3d.renderer.cameras import CamerasBase
28+
2629

2730
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
2831

@@ -296,6 +299,18 @@ def get_category_to_subset_name_list(self) -> Dict[str, List[str]]:
296299
)
297300
return category_to_subset_name_list
298301

302+
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
303+
return {
304+
"manyview": Task.SINGLE_SEQUENCE,
305+
"fewview": Task.MULTI_SEQUENCE,
306+
}[self.subset_name.split("_")[0]]
307+
308+
def get_all_train_cameras(self) -> Optional[CamerasBase]:
309+
# pyre-ignore[16]
310+
train_dataset = self.dataset_map.train
311+
assert isinstance(train_dataset, JsonIndexDataset)
312+
return train_dataset.get_all_train_cameras()
313+
299314
def _load_annotation_json(self, json_filename: str):
300315
full_path = os.path.join(
301316
self.dataset_root,

0 commit comments

Comments
 (0)