Skip to content

Commit cf4384f

Browse files
committed
fix
1 parent 2a509cc commit cf4384f

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

docs/en/user_guides/test.md

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,38 @@ Test time augmentation (TTA) is a data augmentation strategy used during the tes
205205

206206
In MMDetection, we provides [DetTTAModel](../../../mmdet/models/test_time_augs/det_tta.py) class, which inherits from BaseTTAModel.
207207

208-
You can simplely run:
208+
### Use case
209+
210+
Using TTA requires two steps. First, you need to add `tta_model` and `tta_pipeline` in the configuration file:
211+
212+
```shell
213+
tta_model = dict(
214+
type='DetTTAModel',
215+
tta_cfg=dict(nms=dict(
216+
type='nms',
217+
iou_threshold=0.5),
218+
max_per_img=100))
219+
220+
tta_pipeline = [
221+
dict(type='LoadImageFromFile',
222+
file_client_args=dict(backend='disk')),
223+
dict(
224+
type='TestTimeAug',
225+
transforms=[[
226+
dict(type='Resize', scale=(1333, 800), keep_ratio=True)
227+
], [ # It uses 2 flipping enhancements (flipping and not flipping).
228+
dict(type='RandomFlip', prob=1.),
229+
dict(type='RandomFlip', prob=0.)
230+
], [
231+
dict(
232+
type='PackDetInputs',
233+
meta_keys=('img_id', 'img_path', 'ori_shape',
234+
'img_shape', 'scale_factor', 'flip',
235+
'flip_direction'))
236+
]])]
237+
```
238+
239+
Second, you can simplely run:
209240

210241
```shell
211242
# Single-gpu testing
@@ -230,8 +261,7 @@ bash tools/dist_test.sh \
230261
[--tta]
231262
```
232263

233-
By default, we only use 2 flipping enhancements (flipping and not flipping).
234-
You can also modify the config of TTA by yourself, such as adding scaling enhancement:
264+
You can also modify the TTA config by yourself, such as adding scaling enhancement:
235265

236266
```shell
237267
tta_model = dict(
@@ -263,11 +293,11 @@ tta_pipeline = [
263293
264294
The above data augmentation pipeline will first perform 3 multi-scaling enhancements on the image, followed by 2 flipping enhancements (flipping and not flipping). Finally, the image is packaged into the final result using PackDetInputs.
265295
266-
Here are some TTA configs for your reference:
296+
Here are more TTA use cases for your reference:
267297
268-
- [RetinaNet](../../../configs/_base_/tta/retinanet_tta.py)
269-
- [CenterNet](../../../configs/_base_/tta/centernet_tta.py)
270-
- [YOLOX](../../../configs/_base_/tta/rtmdet_tta_.py)
271-
- [RTMDet](../../../configs/_base_/tta/yolox_tta.py)
298+
- [RetinaNet](../../../configs/retinanet/retinanet_tta.py)
299+
- [CenterNet](../../../configs/centernet/centernet_tta.py)
300+
- [YOLOX](../../../configs/rtmdet/rtmdet_tta.py)
301+
- [RTMDet](../../../configs/yolox/yolox_tta.py)
272302
273303
For more advanced usage and data flow of TTA, please refer to [MMEngine](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/test_time_augmentation.html#data-flow). We will support instance segmentation TTA latter.

tools/test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import os
44
import os.path as osp
5+
import warnings
56
from copy import deepcopy
67

78
from mmengine import ConfigDict
@@ -87,12 +88,17 @@ def main():
8788
cfg = trigger_visualization_hook(cfg, args)
8889

8990
if args.tta:
91+
9092
if 'tta_model' not in cfg:
93+
warnings.warn('Cannot find ``tta_model`` in config, '
94+
'we will set it as default.')
9195
cfg.tta_model = dict(
9296
type='DetTTAModel',
9397
tta_cfg=dict(
9498
nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
9599
if 'tta_pipeline' not in cfg:
100+
warnings.warn('Cannot find ``tta_pipeline`` in config, '
101+
'we will set it as default.')
96102
test_data_cfg = cfg.test_dataloader.dataset
97103
while 'dataset' in test_data_cfg:
98104
test_data_cfg = test_data_cfg['dataset']

0 commit comments

Comments
 (0)