@@ -36,15 +36,13 @@ def __init__(
         The total number of positional features will be max_freqs^2.
         """
         super().__init__()
-        self.dtype = dtype
+        self.dtype = dtype
         self.max_freqs = max_freqs
         self.hidden_size_input = hidden_size_input

         # A linear layer to project the concatenated input features and
         # positional encodings to the final output dimension.
-        self.embedder = nn.Sequential(
-            operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
-        )
+        self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)

     @lru_cache(maxsize=4)
     def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
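The constructor hunk above unwraps the single `operations.Linear` from its one-module `nn.Sequential`; for a lone submodule the two forms compute the same thing, the wrapper only affects parameter naming. A minimal sketch of that equivalence, using `torch.nn.Linear` as a stand-in for `operations.Linear` (an assumption) and made-up sizes:

import torch
import torch.nn as nn

# Illustrative sizes only.
in_channels, max_freqs, hidden_size_input = 3, 8, 64

# Before: a single Linear wrapped in nn.Sequential.
old = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input))
# After: the bare Linear.
new = nn.Linear(in_channels + max_freqs**2, hidden_size_input)
# Copy the weights across ("0.weight" -> "weight", "0.bias" -> "bias").
new.load_state_dict({k.split(".", 1)[1]: v for k, v in old.state_dict().items()})

x = torch.randn(2, 16, in_channels + max_freqs**2)
assert torch.allclose(old(x), new(x))  # same parameters, same output

One side effect worth noting: dropping the wrapper renames the state-dict keys from `embedder.0.*` to `embedder.*`, which only matters if checkpoints were saved against the wrapped layout.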
@@ -101,7 +99,7 @@ def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:

         return dct

-    def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the embedder.

@@ -117,16 +115,11 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
         # Infer the patch side length from the number of pixels (P^2).
         patch_size = int(P2**0.5)

-        # Possibly run the operation with a different dtype.
         input_dtype = inputs.dtype
-        if embedder_dtype != input_dtype or self.dtype != input_dtype:
-            embedder = self.embedder.to(dtype=embedder_dtype)
-            inputs = inputs.to(dtype=embedder_dtype)
-        else:
-            embedder = self.embedder
+        inputs = inputs.to(dtype=self.dtype)

         # Fetch the pre-computed or cached positional embeddings.
-        dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
+        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)

         # Repeat the positional embeddings for each item in the batch.
         dct = dct.repeat(B, 1, 1)
@@ -136,10 +129,7 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
         inputs = torch.cat((inputs, dct), dim=-1)

         # Project the combined tensor to the target hidden size.
-        inputs = embedder(inputs)
-
-        # No-op if already the same dtype.
-        return inputs.to(dtype=input_dtype)
+        return self.embedder(inputs).to(dtype=input_dtype)


 class NerfGLUBlock(nn.Module):
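Taken together, the forward-pass hunks replace the per-call `embedder_dtype` plumbing with a single cast to the module's own `self.dtype`, followed by a cast back to the caller's dtype on return. A self-contained sketch of that flow, with `torch.nn.Linear` standing in for `operations.Linear`, a zero tensor standing in for the cached DCT features, and a made-up class name and input layout (batch, pixels-per-patch, channels), all assumptions for illustration:

import torch
import torch.nn as nn


class PatchEmbedderSketch(nn.Module):
    """Hypothetical, simplified stand-in for the real embedder module."""

    def __init__(self, in_channels=3, max_freqs=8, hidden_size_input=64, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.max_freqs = max_freqs
        self.embedder = nn.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        B, P2, _ = inputs.shape
        patch_size = int(P2**0.5)  # side length of the square patch

        # Cast once to the module's own dtype, remembering the caller's dtype.
        input_dtype = inputs.dtype
        inputs = inputs.to(dtype=self.dtype)

        # Stand-in for fetch_pos(patch_size, device, dtype): one row of
        # max_freqs**2 positional features per pixel in the patch.
        dct = torch.zeros(patch_size**2, self.max_freqs**2, device=inputs.device, dtype=self.dtype)
        dct = dct.repeat(B, 1, 1)  # one copy per batch item

        # Append positional features, project, and restore the caller's dtype.
        inputs = torch.cat((inputs, dct), dim=-1)
        return self.embedder(inputs).to(dtype=input_dtype)


emb = PatchEmbedderSketch()
x = torch.randn(2, 16, 3, dtype=torch.float16)  # e.g. a half-precision caller
out = emb(x)
print(out.shape, out.dtype)                     # torch.Size([2, 16, 64]) torch.float16

As the removed comment noted, the trailing `.to(dtype=input_dtype)` is a no-op when the caller already uses `self.dtype`, so the common path costs nothing extra.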