diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
index 32e9e2d02..780abd43a 100644
--- a/MODEL_ZOO.md
+++ b/MODEL_ZOO.md
@@ -25,6 +25,12 @@ We provided original pretrained models from Caffe2 on heavy models (testing Caff
| X3D | M | - | 16 x 5 | 75.1 | 76.2 | 3.8 | 4.73 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/x3d_models/x3d_m.pyth) | Kinetics/X3D_M |
| X3D | L | - | 16 x 5 | 76.9 | 77.5 | 6.2 | 18.37 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/x3d_models/x3d_l.pyth) | Kinetics/X3D_L |
+
+## VTN model (details in [`projects/vtn`](projects/vtn/README.md))
+
+| architecture | backbone | pretrain | frame length x sample rate | top1 | top5 | model | config |
+| :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: | :-------------: |
+| VTN | ViT-B | ImageNet-21K | - | 77.72 | 93.24 | [`link`](https://researchpublic.blob.core.windows.net/vtn/VTN_VIT_B_KINETICS.pyth) | Kinetics/VIT_B_VTN |
+
## AVA
| architecture | depth | Pretrain Model | frame length x sample rate | MAP | AVA version | model |
@@ -67,4 +73,4 @@ We also release the imagenet pretrained model if finetuning from ImageNet is pre
| architecture | depth | Top1 | Top5 | model |
| ------------- | ------------- | ------------- | ------------- | ------------- |
-| ResNet | R50 | 23.6 | 6.8 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/R50_IN1K.pyth) |
+| ResNet | R50 | 23.6 | 6.8 | [`link`](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/R50_IN1K.pyth) |
\ No newline at end of file
diff --git a/README.md b/README.md
index 07a7221e8..eb8c56dc1 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@ PySlowFast is an open source video understanding codebase from FAIR that provide
- [Non-local Neural Networks](https://arxiv.org/abs/1711.07971)
- [A Multigrid Method for Efficiently Training Video Models](https://arxiv.org/abs/1912.00998)
- [X3D: Progressive Network Expansion for Efficient Video Recognition](https://arxiv.org/abs/2004.04730)
+- [Video Transformer Network](https://arxiv.org/abs/2102.00719)

@@ -21,8 +22,10 @@ The goal of PySlowFast is to provide a high-performance, light-weight pytorch co
- I3D
- Non-local Network
- X3D
+- VTN
## Updates
+ - We now support [VTN Model](https://arxiv.org/abs/2102.00719). See [`projects/vtn`](./projects/vtn/README.md) for more information.
- We now support [X3D Models](https://arxiv.org/abs/2004.04730). See [`projects/x3d`](./projects/x3d/README.md) for more information.
- We now support [Multigrid Training](https://arxiv.org/abs/1912.00998) for efficiently training video models. See [`projects/multigrid`](./projects/multigrid/README.md) for more information.
- PySlowFast is released in conjunction with our [ICCV 2019 Tutorial](https://alexander-kirillov.github.io/tutorials/visual-recognition-iccv19/).
diff --git a/configs/Kinetics/VIT_B_VTN.yaml b/configs/Kinetics/VIT_B_VTN.yaml
new file mode 100644
index 000000000..ecd33b878
--- /dev/null
+++ b/configs/Kinetics/VIT_B_VTN.yaml
@@ -0,0 +1,60 @@
+TRAIN:
+ ENABLE: True
+ DATASET: kinetics
+ BATCH_SIZE: 16
+ EVAL_PERIOD: 1
+ CHECKPOINT_PERIOD: 1
+ AUTO_RESUME: True
+ EVAL_FULL_VIDEO: True
+ EVAL_NUM_FRAMES: 250
+DATA:
+ NUM_FRAMES: 16
+ SAMPLING_RATE: 8
+ TARGET_FPS: 25
+ TRAIN_JITTER_SCALES: [256, 320]
+ TRAIN_CROP_SIZE: 224
+ TEST_CROP_SIZE: 224
+ INPUT_CHANNEL_NUM: [3]
+SOLVER:
+ BASE_LR: 0.001
+ LR_POLICY: steps_with_relative_lrs
+ STEPS: [0, 13, 24]
+ LRS: [1, 0.1, 0.01]
+ MAX_EPOCH: 25
+ MOMENTUM: 0.9
+ OPTIMIZING_METHOD: sgd
+MODEL:
+ NUM_CLASSES: 400
+ ARCH: VIT
+ MODEL_NAME: VTN
+ LOSS_FUNC: cross_entropy
+ DROPOUT_RATE: 0.5
+VTN:
+ PRETRAINED: True
+ MLP_DIM: 768
+ DROP_PATH_RATE: 0.0
+ DROP_RATE: 0.0
+ HIDDEN_DIM: 768
+ MAX_POSITION_EMBEDDINGS: 288
+ NUM_ATTENTION_HEADS: 12
+ NUM_HIDDEN_LAYERS: 3
+ ATTENTION_MODE: 'sliding_chunks'
+ PAD_TOKEN_ID: -1
+ ATTENTION_WINDOW: [18, 18, 18]
+ INTERMEDIATE_SIZE: 3072
+ ATTENTION_PROBS_DROPOUT_PROB: 0.1
+ HIDDEN_DROPOUT_PROB: 0.1
+TEST:
+ ENABLE: True
+ DATASET: kinetics
+ BATCH_SIZE: 16
+ NUM_ENSEMBLE_VIEWS: 1
+ NUM_SPATIAL_CROPS: 1
+DATA_LOADER:
+ NUM_WORKERS: 8
+ PIN_MEMORY: True
+NUM_GPUS: 4
+NUM_SHARDS: 1
+RNG_SEED: 0
+OUTPUT_DIR: .
+LOG_MODEL_INFO: False
\ No newline at end of file
diff --git a/projects/vtn/README.md b/projects/vtn/README.md
new file mode 100644
index 000000000..fe4a9f0ac
--- /dev/null
+++ b/projects/vtn/README.md
@@ -0,0 +1,70 @@
+# Video Transformer Network
+Daniel Neimark, Omri Bar, Maya Zohar, Dotan Asselmann [[Paper](https://arxiv.org/abs/2102.00719)]
+
+![VTN demo](fig/vtn_demo.gif)
+
+![VTN architecture](fig/arch.png)
+
+## Installation
+```
+pip install timm
+pip install transformers[torch]
+```
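+
+As a quick sanity check (a minimal sketch, not part of the codebase; package versions are not pinned here),
+you can verify that both dependencies expose the pieces VTN relies on:
+
+```python
+# If these imports succeed, the VTN dependencies (timm and transformers) are available.
+from timm.models.vision_transformer import vit_base_patch16_224
+from transformers import LongformerConfig
+
+backbone = vit_base_patch16_224(pretrained=False, num_classes=0)  # ViT-B/16 feature extractor, no classifier head
+print("ViT-B embed dim:", backbone.embed_dim)                     # 768, matches VTN.HIDDEN_DIM in the config
+print("Longformer default hidden size:", LongformerConfig().hidden_size)
+```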
+
+## Getting started
+To use VTN models, please refer to the configs under `configs/Kinetics`, or see
+the [MODEL_ZOO.md](https://github.com/facebookresearch/SlowFast/blob/master/MODEL_ZOO.md)
+for pre-trained models*.
+
+To train ViT-B-VTN on your dataset (see [paper](https://arxiv.org/abs/2102.00719) for details):
+```
+python tools/run_net.py \
+ --cfg configs/Kinetics/VIT_B_VTN.yaml \
+  DATA.PATH_TO_DATA_DIR path_to_your_dataset
+```
+
+To test the trained ViT-B-VTN on the Kinetics-400 dataset:
+```
+python tools/run_net.py \
+ --cfg configs/Kinetics/VIT_B_VTN.yaml \
+ DATA.PATH_TO_DATA_DIR path_to_kinetics_dataset \
+ TRAIN.ENABLE False \
+ TEST.CHECKPOINT_FILE_PATH path_to_model \
+ TEST.CHECKPOINT_TYPE pytorch
+```
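+
+The VTN model can also be instantiated directly from the config for quick experimentation. The snippet
+below is a minimal sketch, not part of the training pipeline: it assumes you run it from the repository
+root, disables `VTN.PRETRAINED` to avoid downloading ImageNet weights, and feeds a random clip:
+
+```python
+import torch
+from slowfast.config.defaults import get_cfg
+from slowfast.models.video_model_builder import VTN
+
+cfg = get_cfg()
+cfg.merge_from_file("configs/Kinetics/VIT_B_VTN.yaml")
+cfg.VTN.PRETRAINED = False  # skip the ImageNet weight download for this smoke test
+
+model = VTN(cfg).eval()
+
+# VTN consumes a pair [frames, frame_indices]:
+#   frames:        (batch, channels, num_frames, height, width)
+#   frame_indices: (batch, num_frames) - the sampled frame ids, used as temporal position ids
+frames = torch.randn(1, 3, cfg.DATA.NUM_FRAMES, cfg.DATA.TEST_CROP_SIZE, cfg.DATA.TEST_CROP_SIZE)
+frame_indices = torch.arange(cfg.DATA.NUM_FRAMES).unsqueeze(0)
+
+with torch.no_grad():
+    logits = model([frames, frame_indices])
+print(logits.shape)  # torch.Size([1, 400]) - Kinetics-400 class scores
+```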
+
+\* VTN models in [MODEL_ZOO.md](https://github.com/facebookresearch/SlowFast/blob/master/MODEL_ZOO.md) produce slightly
+different results than those reported in the paper due to differences between the PySlowFast code base and the
+original code used to train the models (mainly around data and video loading).
+
+## Citing VTN
+If you find VTN useful for your research, please consider citing the paper using the following BibTeX entry.
+```BibTeX
+@article{neimark2021video,
+ title={Video Transformer Network},
+ author={Neimark, Daniel and Bar, Omri and Zohar, Maya and Asselmann, Dotan},
+ journal={arXiv preprint arXiv:2102.00719},
+ year={2021}
+}
+```
+
+
+## Additional Qualitative Results
+
+![Qualitative example a](fig/a.png)
+
+Label: Tai chi. Prediction: Tai chi.
+
+![Qualitative example b](fig/b.png)
+
+Label: Chopping wood. Prediction: Chopping wood.
+
+![Qualitative example c](fig/c.png)
+
+Label: Archery. Prediction: Archery.
+
+![Qualitative example d](fig/d.png)
+
+Label: Throwing discus. Prediction: Flying kite.
+
+![Qualitative example e](fig/e.png)
+
+Label: Surfing water. Prediction: Parasailing.
diff --git a/projects/vtn/fig/a.png b/projects/vtn/fig/a.png
new file mode 100644
index 000000000..967106b0d
Binary files /dev/null and b/projects/vtn/fig/a.png differ
diff --git a/projects/vtn/fig/arch.png b/projects/vtn/fig/arch.png
new file mode 100644
index 000000000..b09f0fa9e
Binary files /dev/null and b/projects/vtn/fig/arch.png differ
diff --git a/projects/vtn/fig/b.png b/projects/vtn/fig/b.png
new file mode 100644
index 000000000..f29d29f5a
Binary files /dev/null and b/projects/vtn/fig/b.png differ
diff --git a/projects/vtn/fig/c.png b/projects/vtn/fig/c.png
new file mode 100644
index 000000000..e43a35319
Binary files /dev/null and b/projects/vtn/fig/c.png differ
diff --git a/projects/vtn/fig/d.png b/projects/vtn/fig/d.png
new file mode 100644
index 000000000..42ca2cde9
Binary files /dev/null and b/projects/vtn/fig/d.png differ
diff --git a/projects/vtn/fig/e.png b/projects/vtn/fig/e.png
new file mode 100644
index 000000000..c0e246fd3
Binary files /dev/null and b/projects/vtn/fig/e.png differ
diff --git a/projects/vtn/fig/vtn_demo.gif b/projects/vtn/fig/vtn_demo.gif
new file mode 100644
index 000000000..e271ce553
Binary files /dev/null and b/projects/vtn/fig/vtn_demo.gif differ
diff --git a/slowfast/config/defaults.py b/slowfast/config/defaults.py
index 718801a92..e51c5a809 100644
--- a/slowfast/config/defaults.py
+++ b/slowfast/config/defaults.py
@@ -75,6 +75,12 @@
# If set, clear all layer names according to the pattern provided.
_C.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN = () # ("backbone.",)
+# If True, use all of the video's frames during evaluation
+_C.TRAIN.EVAL_FULL_VIDEO = False
+
+# If "EVAL_FULL_VIDEO" is True, the number of frames to sample from the full video (250 for VTN)
+_C.TRAIN.EVAL_NUM_FRAMES = None
+
# ---------------------------------------------------------------------------- #
# Testing options
# ---------------------------------------------------------------------------- #
@@ -254,6 +260,53 @@
# pathway.
_C.SLOWFAST.FUSION_KERNEL_SZ = 5
+# -----------------------------------------------------------------------------
+# VTN options
+# -----------------------------------------------------------------------------
+_C.VTN = CfgNode()
+
+# ViT: if True, will load pretrained weights for the backbone.
+_C.VTN.PRETRAINED = True
+
+# ViT: stochastic depth decay rule.
+_C.VTN.DROP_PATH_RATE = 0.0
+
+# ViT: dropout ratio.
+_C.VTN.DROP_RATE = 0.0
+
+# Longformer: the embedding size; this is also the input size of the MLP head
+# and should match the ViT output dimension.
+_C.VTN.HIDDEN_DIM = 768
+
+# Longformer: the maximum sequence length that this model might ever be used with.
+_C.VTN.MAX_POSITION_EMBEDDINGS = 288
+
+# Longformer: number of attention heads for each attention layer in the Transformer encoder.
+_C.VTN.NUM_ATTENTION_HEADS = 12
+
+# Longformer: number of hidden layers in the Transformer encoder.
+_C.VTN.NUM_HIDDEN_LAYERS = 3
+
+# Longformer: type of self-attention; 'sliding_chunks' processes the sequence with a sliding window
+_C.VTN.ATTENTION_MODE = 'sliding_chunks'
+
+# Longformer: The value used to pad input_ids.
+_C.VTN.PAD_TOKEN_ID = -1
+
+# Longformer: Size of an attention window around each token.
+_C.VTN.ATTENTION_WINDOW = [18, 18, 18]
+
+# Longformer: Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+_C.VTN.INTERMEDIATE_SIZE = 3072
+
+# Longformer: The dropout ratio for the attention probabilities.
+_C.VTN.ATTENTION_PROBS_DROPOUT_PROB = 0.1
+
+# Longformer: The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+_C.VTN.HIDDEN_DROPOUT_PROB = 0.1
+
+# MLP Head: the dimension of the MLP head hidden layer.
+_C.VTN.MLP_DIM = 768
# -----------------------------------------------------------------------------
# Data options
diff --git a/slowfast/datasets/decoder.py b/slowfast/datasets/decoder.py
index efd582859..af733efa3 100644
--- a/slowfast/datasets/decoder.py
+++ b/slowfast/datasets/decoder.py
@@ -25,7 +25,7 @@ def temporal_sampling(frames, start_idx, end_idx, num_samples):
index = torch.linspace(start_idx, end_idx, num_samples)
index = torch.clamp(index, 0, frames.shape[0] - 1).long()
frames = torch.index_select(frames, 0, index)
- return frames
+ return frames, index
def get_start_end_idx(video_size, clip_size, clip_idx, num_clips):
@@ -212,7 +212,7 @@ def torchvision_decode(
def pyav_decode(
- container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30
+ container, sampling_rate, num_frames, clip_idx, num_clips=10, target_fps=30, force_all_video=False
):
"""
Convert the video from its original fps to the target_fps. If the video
@@ -233,6 +233,7 @@ def pyav_decode(
given video.
target_fps (int): the input video may has different fps, convert it to
the target video fps before frame sampling.
+        force_all_video (bool): if True, decode all of the video's frames.
Returns:
frames (tensor): decoded frames from the video. Return None if the no
video stream was found.
@@ -246,7 +247,7 @@ def pyav_decode(
frames_length = container.streams.video[0].frames
duration = container.streams.video[0].duration
- if duration is None:
+ if duration is None or force_all_video:
# If failed to fetch the decoding information, decode the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, math.inf
@@ -290,6 +291,7 @@ def decode(
target_fps=30,
backend="pyav",
max_spatial_scale=0,
+ force_all_video=False,
):
"""
Decode the video and perform temporal sampling.
@@ -313,6 +315,7 @@ def decode(
max_spatial_scale (int): keep the aspect ratio and resize the frame so
that shorter edge size is max_spatial_scale. Only used in
`torchvision` backend.
+        force_all_video (bool): if True, decode all of the video's frames. Only supported with the `pyav` backend.
Returns:
frames (tensor): decoded frames from the video.
"""
@@ -327,6 +330,7 @@ def decode(
clip_idx,
num_clips,
target_fps,
+ force_all_video,
)
elif backend == "torchvision":
frames, fps, decode_all_video = torchvision_decode(
@@ -346,11 +350,11 @@ def decode(
)
except Exception as e:
print("Failed to decode by {} with exception: {}".format(backend, e))
- return None
+ return None, None
# Return None if the frames was not decoded successfully.
if frames is None or frames.size(0) == 0:
- return None
+ return None, None
clip_sz = sampling_rate * num_frames / target_fps * fps
start_idx, end_idx = get_start_end_idx(
@@ -359,6 +363,11 @@ def decode(
clip_idx if decode_all_video else 0,
num_clips if decode_all_video else 1,
)
+
+ if force_all_video:
+        # Avoid duplicating the last frame for videos shorter than the requested number of frames (e.g. 250).
+ end_idx = min(float(frames.shape[0]), end_idx)
+
# Perform temporal sampling from the decoded video.
- frames = temporal_sampling(frames, start_idx, end_idx, num_frames)
- return frames
+ frames, frames_index = temporal_sampling(frames, start_idx, end_idx, num_frames)
+ return frames, frames_index
diff --git a/slowfast/datasets/kinetics.py b/slowfast/datasets/kinetics.py
index 28036573d..f402e4c52 100644
--- a/slowfast/datasets/kinetics.py
+++ b/slowfast/datasets/kinetics.py
@@ -70,6 +70,17 @@ def __init__(self, cfg, mode, num_retries=10):
cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
)
+ if self.mode in ["val", "test"] and cfg.TRAIN.EVAL_FULL_VIDEO:
+ # supporting full video evaluation
+ self.force_all_video = True
+ self.num_frames = self.cfg.TRAIN.EVAL_NUM_FRAMES
+ self.sampling_rate = 1
+ self._num_clips = 1
+ else:
+ self.force_all_video = False
+ self.num_frames = self.cfg.DATA.NUM_FRAMES
+ self.sampling_rate = self.cfg.DATA.SAMPLING_RATE
+
logger.info("Constructing Kinetics {}...".format(mode))
self._construct_loader()
@@ -158,6 +169,16 @@ def __getitem__(self, index):
/ self.cfg.MULTIGRID.DEFAULT_S
)
)
+ if self.mode in ["val"] and self.cfg.TRAIN.EVAL_FULL_VIDEO:
+            # Support full-video evaluation:
+            # spatial_sample_index=1 takes only the center crop.
+            # The evaluation is deterministic and no jitter should be performed;
+            # min_scale and max_scale are expected to be the same (crop_size may differ).
+            # temporal_sample_index stays -1 (random); with force_all_video the whole video ([0, inf]) is decoded anyway.
+ spatial_sample_index = 1
+ min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
+ max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
+ crop_size = self.cfg.DATA.TEST_CROP_SIZE
elif self.mode in ["test"]:
temporal_sample_index = (
self._spatial_temporal_idx[index]
@@ -189,7 +210,7 @@ def __getitem__(self, index):
)
sampling_rate = utils.get_random_sampling_rate(
self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
- self.cfg.DATA.SAMPLING_RATE,
+ self.sampling_rate,
)
# Try to decode and sample a clip from a video. If the video can not be
# decoded, repeatly find a random video replacement that can be decoded.
@@ -220,16 +241,17 @@ def __getitem__(self, index):
continue
# Decode video. Meta info is used to perform selective decoding.
- frames = decoder.decode(
+ frames, frames_index = decoder.decode(
video_container,
sampling_rate,
- self.cfg.DATA.NUM_FRAMES,
+ self.num_frames,
temporal_sample_index,
self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
video_meta=self._video_meta[index],
target_fps=self.cfg.DATA.TARGET_FPS,
backend=self.cfg.DATA.DECODING_BACKEND,
max_spatial_scale=min_scale,
+ force_all_video=self.force_all_video
)
# If decoding failed (wrong format, video is too short, and etc),
@@ -263,7 +285,7 @@ def __getitem__(self, index):
)
label = self._labels[index]
- frames = utils.pack_pathway_output(self.cfg, frames)
+ frames = utils.pack_pathway_output(self.cfg, frames, frames_index)
return frames, label, index, {}
else:
raise RuntimeError(
diff --git a/slowfast/datasets/ssv2.py b/slowfast/datasets/ssv2.py
index 5e4c2a6aa..d6d5fe33d 100644
--- a/slowfast/datasets/ssv2.py
+++ b/slowfast/datasets/ssv2.py
@@ -265,6 +265,7 @@ def __getitem__(self, index):
inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
)
frames = utils.pack_pathway_output(self.cfg, frames)
+
return frames, label, index, {}
def __len__(self):
diff --git a/slowfast/datasets/utils.py b/slowfast/datasets/utils.py
index 08a4de1a6..98ecb2b4b 100644
--- a/slowfast/datasets/utils.py
+++ b/slowfast/datasets/utils.py
@@ -70,7 +70,7 @@ def get_sequence(center_idx, half_len, sample_rate, num_frames):
return seq
-def pack_pathway_output(cfg, frames):
+def pack_pathway_output(cfg, frames, frames_index=None):
"""
Prepare output as a list of tensors. Each tensor corresponding to a
unique pathway.
@@ -83,7 +83,9 @@ def pack_pathway_output(cfg, frames):
"""
if cfg.DATA.REVERSE_INPUT_CHANNEL:
frames = frames[[2, 1, 0], :, :, :]
- if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH:
+ if cfg.MODEL.MODEL_NAME == "VTN":
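+        # VTN expects [frames, frames_index]; the sampled frame indices become the temporal position ids.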
+ frame_list = [frames, frames_index]
+ elif cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH:
frame_list = [frames]
elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH:
fast_pathway = frames
@@ -151,8 +153,8 @@ def spatial_sampling(
frames, _ = transform.horizontal_flip(0.5, frames)
else:
# The testing is deterministic and no jitter should be performed.
- # min_scale, max_scale, and crop_size are expect to be the same.
- assert len({min_scale, max_scale, crop_size}) == 1
+        # min_scale and max_scale are expected to be the same.
+ assert min_scale == max_scale
frames, _ = transform.random_short_side_scale_jitter(
frames, min_scale, max_scale
)
diff --git a/slowfast/models/video_model_builder.py b/slowfast/models/video_model_builder.py
index 85a4ed1a9..355c48a66 100644
--- a/slowfast/models/video_model_builder.py
+++ b/slowfast/models/video_model_builder.py
@@ -6,11 +6,12 @@
import math
import torch
import torch.nn as nn
+from timm.models.vision_transformer import vit_base_patch16_224
import slowfast.utils.weight_init_helper as init_helper
from slowfast.models.batchnorm_helper import get_norm
-from . import head_helper, resnet_helper, stem_helper
+from . import head_helper, resnet_helper, stem_helper, vtn_helper
from .build import MODEL_REGISTRY
# Number of blocks for different stages given the model depth.
@@ -758,3 +759,112 @@ def forward(self, x, bboxes=None):
for module in self.children():
x = module(x)
return x
+
+
+@MODEL_REGISTRY.register()
+class VTN(nn.Module):
+ """
+ VTN model builder. It uses ViT-Base as the backbone.
+
+ Daniel Neimark, Omri Bar, Maya Zohar and Dotan Asselmann.
+ "Video Transformer Network."
+ https://arxiv.org/abs/2102.00719
+ """
+
+ def __init__(self, cfg):
+ """
+ The `__init__` method of any subclass should also contain these
+ arguments.
+ Args:
+ cfg (CfgNode): model building configs, details are in the
+ comments of the config file.
+ """
+ super(VTN, self).__init__()
+ self._construct_network(cfg)
+
+ def _construct_network(self, cfg):
+ """
+ Builds a VTN model, with a given backbone architecture.
+ Args:
+ cfg (CfgNode): model building configs, details are in the
+ comments of the config file.
+ """
+ if cfg.MODEL.ARCH == "VIT":
+ self.backbone = vit_base_patch16_224(pretrained=cfg.VTN.PRETRAINED,
+ num_classes=0,
+ drop_path_rate=cfg.VTN.DROP_PATH_RATE,
+ drop_rate=cfg.VTN.DROP_RATE)
+ else:
+ raise NotImplementedError(f"not supporting {cfg.MODEL.ARCH}")
+
+ embed_dim = self.backbone.embed_dim
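+        # Learnable classification token, prepended to the sequence of frame embeddings in forward().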
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+
+ self.temporal_encoder = vtn_helper.VTNLongformerModel(
+ embed_dim=embed_dim,
+ max_position_embeddings=cfg.VTN.MAX_POSITION_EMBEDDINGS,
+ num_attention_heads=cfg.VTN.NUM_ATTENTION_HEADS,
+ num_hidden_layers=cfg.VTN.NUM_HIDDEN_LAYERS,
+ attention_mode=cfg.VTN.ATTENTION_MODE,
+ pad_token_id=cfg.VTN.PAD_TOKEN_ID,
+ attention_window=cfg.VTN.ATTENTION_WINDOW,
+ intermediate_size=cfg.VTN.INTERMEDIATE_SIZE,
+ attention_probs_dropout_prob=cfg.VTN.ATTENTION_PROBS_DROPOUT_PROB,
+ hidden_dropout_prob=cfg.VTN.HIDDEN_DROPOUT_PROB)
+
+ self.mlp_head = nn.Sequential(
+ nn.LayerNorm(cfg.VTN.HIDDEN_DIM),
+ nn.Linear(cfg.VTN.HIDDEN_DIM, cfg.VTN.MLP_DIM),
+ nn.GELU(),
+ nn.Dropout(cfg.MODEL.DROPOUT_RATE),
+ nn.Linear(cfg.VTN.MLP_DIM, cfg.MODEL.NUM_CLASSES)
+ )
+
+ def forward(self, x, bboxes=None):
+
+ x, position_ids = x
+
+ # spatial backbone
+ B, C, F, H, W = x.shape
+ x = x.permute(0, 2, 1, 3, 4)
+ x = x.reshape(B * F, C, H, W)
+ x = self.backbone(x)
+ x = x.reshape(B, F, -1)
+
+ # temporal encoder (Longformer)
+ B, D, E = x.shape
+ attention_mask = torch.ones((B, D), dtype=torch.long, device=x.device)
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ cls_atten = torch.ones(1).expand(B, -1).to(x.device)
+ attention_mask = torch.cat((attention_mask, cls_atten), dim=1)
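+        # A value of 2 in the attention mask marks the CLS position (index 0) for Longformer global attention.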
+ attention_mask[:, 0] = 2
+ x, attention_mask, position_ids = vtn_helper.pad_to_window_size_local(
+ x,
+ attention_mask,
+ position_ids,
+ self.temporal_encoder.config.attention_window[0],
+ self.temporal_encoder.config.pad_token_id)
+ token_type_ids = torch.zeros(x.size()[:-1], dtype=torch.long, device=x.device)
+ token_type_ids[:, 0] = 1
+
+ # position_ids
+ position_ids = position_ids.long()
+ mask = attention_mask.ne(0).int()
+ max_position_embeddings = self.temporal_encoder.config.max_position_embeddings
+ position_ids = position_ids % (max_position_embeddings - 2)
+ position_ids[:, 0] = max_position_embeddings - 2
+ position_ids[mask == 0] = max_position_embeddings - 1
+
+ x = self.temporal_encoder(input_ids=None,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=x,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None)
+ # MLP head
+ x = x["last_hidden_state"]
+ x = self.mlp_head(x[:, 0])
+ return x
diff --git a/slowfast/models/vtn_helper.py b/slowfast/models/vtn_helper.py
new file mode 100644
index 000000000..0c7b5f831
--- /dev/null
+++ b/slowfast/models/vtn_helper.py
@@ -0,0 +1,55 @@
+import torch
+from transformers import LongformerModel, LongformerConfig
+import torch.nn.functional as F
+
+
+class VTNLongformerModel(LongformerModel):
+
+ def __init__(self,
+ embed_dim=768,
+ max_position_embeddings=2 * 60 * 60,
+ num_attention_heads=12,
+ num_hidden_layers=3,
+ attention_mode='sliding_chunks',
+ pad_token_id=-1,
+ attention_window=None,
+ intermediate_size=3072,
+ attention_probs_dropout_prob=0.1,
+ hidden_dropout_prob=0.1):
+
+ self.config = LongformerConfig()
+ self.config.attention_mode = attention_mode
+ self.config.intermediate_size = intermediate_size
+ self.config.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.config.hidden_dropout_prob = hidden_dropout_prob
+ self.config.attention_dilation = [1, ] * num_hidden_layers
+ self.config.attention_window = [256, ] * num_hidden_layers if attention_window is None else attention_window
+ self.config.num_hidden_layers = num_hidden_layers
+ self.config.num_attention_heads = num_attention_heads
+ self.config.pad_token_id = pad_token_id
+ self.config.max_position_embeddings = max_position_embeddings
+ self.config.hidden_size = embed_dim
+ super(VTNLongformerModel, self).__init__(self.config, add_pooling_layer=False)
+ self.embeddings.word_embeddings = None # to avoid distributed error of unused parameters
+
+
+def pad_to_window_size_local(input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor,
+ one_sided_window_size: int, pad_token_id: int):
+ '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
+ Based on _pad_to_window_size from https://github.com/huggingface/transformers:
+ https://github.com/huggingface/transformers/blob/71bdc076dd4ba2f3264283d4bc8617755206dccd/src/transformers/models/longformer/modeling_longformer.py#L1516
+ Input:
+ input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
+        attention_mask = torch.Tensor(bsz x seqlen): attention mask
+        position_ids = torch.Tensor(bsz x seqlen): position ids of the tokens
+        one_sided_window_size = int: window size on one side of each token
+ pad_token_id = int: tokenizer.pad_token_id
+ Returns
+        (input_ids, attention_mask, position_ids) padded to a length divisible by 2 * one_sided_window_size
+ '''
+ w = 2 * one_sided_window_size
+ seqlen = input_ids.size(1)
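+    # e.g. 16 frames + 1 CLS token with one_sided_window_size=18: w = 36, padding_len = (36 - 17 % 36) % 36 = 19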
+ padding_len = (w - seqlen % w) % w
+ input_ids = F.pad(input_ids.permute(0, 2, 1), (0, padding_len), value=pad_token_id).permute(0, 2, 1)
+ attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
+    position_ids = F.pad(position_ids, (1, padding_len), value=False)  # prepend a slot for the CLS token, then pad to the window size
+ return input_ids, attention_mask, position_ids
diff --git a/slowfast/utils/misc.py b/slowfast/utils/misc.py
index 5684ef83f..1db4758ae 100644
--- a/slowfast/utils/misc.py
+++ b/slowfast/utils/misc.py
@@ -103,7 +103,13 @@ def _get_model_analysis_input(cfg, use_train_input):
cfg.DATA.TEST_CROP_SIZE,
cfg.DATA.TEST_CROP_SIZE,
)
- model_inputs = pack_pathway_output(cfg, input_tensors)
+
+ if cfg.MODEL.MODEL_NAME == "VTN":
+ frames_index = torch.arange(input_tensors.shape[1])
+ else:
+ frames_index = None
+
+ model_inputs = pack_pathway_output(cfg, input_tensors, frames_index)
for i in range(len(model_inputs)):
model_inputs[i] = model_inputs[i].unsqueeze(0)
if cfg.NUM_GPUS: