File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 1212from __future__ import annotations
1313
1414from collections .abc import Sequence
15+ from typing import Optional
1516
1617import numpy as np
1718import torch
@@ -53,7 +54,7 @@ def __init__(
5354 pos_embed_type : str = "learnable" ,
5455 dropout_rate : float = 0.0 ,
5556 spatial_dims : int = 3 ,
56- pos_embed_kwargs : dict = {} ,
57+ pos_embed_kwargs : Optional [ dict ] = None ,
5758 ) -> None :
5859 """
5960 Args:
@@ -108,6 +109,8 @@ def __init__(
108109 self .position_embeddings = nn .Parameter (torch .zeros (1 , self .n_patches , hidden_size ))
109110 self .dropout = nn .Dropout (dropout_rate )
110111
112+ pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs
113+
111114 if self .pos_embed_type == "none" :
112115 pass
113116 elif self .pos_embed_type == "learnable" :
You can’t perform that action at this time.
0 commit comments