@@ -403,26 +403,19 @@ def log(
403403
404404 # register logged value if it doesn't exist
405405 if key not in self :
406- self .register_key (key , meta , value )
406+ metric = _ResultMetric (meta , isinstance (value , Tensor ))
407+ self [key ] = metric
407408
408409 # check the stored metadata and the current one match
409410 elif meta != self [key ].meta :
410411 raise MisconfigurationException (
411412 f"You called `self.log({ name } , ...)` twice in `{ fx } ` with different arguments. This is not allowed"
412413 )
414+ self [key ].to (value .device )
413415
414416 batch_size = self ._extract_batch_size (self [key ], batch_size , meta )
415417 self .update_metrics (key , value , batch_size )
416418
417- def register_key (self , key : str , meta : _Metadata , value : _VALUE ) -> None :
418- """Create one _ResultMetric object per value.
419-
420- Value can be provided as a nested collection
421-
422- """
423- metric = _ResultMetric (meta , isinstance (value , Tensor )).to (value .device )
424- self [key ] = metric
425-
426419 def update_metrics (self , key : str , value : _VALUE , batch_size : int ) -> None :
427420 result_metric = self [key ]
428421 # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
0 commit comments