@@ -273,72 +273,74 @@ def read_video(
273273 raise RuntimeError (f"File not found: { filename } " )
274274
275275 if get_video_backend () != "pyav" :
276- return _video_opt ._read_video (filename , start_pts , end_pts , pts_unit )
277-
278- _check_av_available ()
279-
280- if end_pts is None :
281- end_pts = float ("inf" )
282-
283- if end_pts < start_pts :
284- raise ValueError (f"end_pts should be larger than start_pts, got start_pts={ start_pts } and end_pts={ end_pts } " )
285-
286- info = {}
287- video_frames = []
288- audio_frames = []
289- audio_timebase = _video_opt .default_timebase
290-
291- try :
292- with av .open (filename , metadata_errors = "ignore" ) as container :
293- if container .streams .audio :
294- audio_timebase = container .streams .audio [0 ].time_base
295- if container .streams .video :
296- video_frames = _read_from_stream (
297- container ,
298- start_pts ,
299- end_pts ,
300- pts_unit ,
301- container .streams .video [0 ],
302- {"video" : 0 },
303- )
304- video_fps = container .streams .video [0 ].average_rate
305- # guard against potentially corrupted files
306- if video_fps is not None :
307- info ["video_fps" ] = float (video_fps )
308-
309- if container .streams .audio :
310- audio_frames = _read_from_stream (
311- container ,
312- start_pts ,
313- end_pts ,
314- pts_unit ,
315- container .streams .audio [0 ],
316- {"audio" : 0 },
317- )
318- info ["audio_fps" ] = container .streams .audio [0 ].rate
319-
320- except av .AVError :
321- # TODO raise a warning?
322- pass
323-
324- vframes_list = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
325- aframes_list = [frame .to_ndarray () for frame in audio_frames ]
326-
327- if vframes_list :
328- vframes = torch .as_tensor (np .stack (vframes_list ))
329- else :
330- vframes = torch .empty ((0 , 1 , 1 , 3 ), dtype = torch .uint8 )
331-
332- if aframes_list :
333- aframes = np .concatenate (aframes_list , 1 )
334- aframes = torch .as_tensor (aframes )
335- if pts_unit == "sec" :
336- start_pts = int (math .floor (start_pts * (1 / audio_timebase )))
337- if end_pts != float ("inf" ):
338- end_pts = int (math .ceil (end_pts * (1 / audio_timebase )))
339- aframes = _align_audio_frames (aframes , audio_frames , start_pts , end_pts )
276+ vframes , aframes , info = _video_opt ._read_video (filename , start_pts , end_pts , pts_unit )
340277 else :
341- aframes = torch .empty ((1 , 0 ), dtype = torch .float32 )
278+ _check_av_available ()
279+
280+ if end_pts is None :
281+ end_pts = float ("inf" )
282+
283+ if end_pts < start_pts :
284+ raise ValueError (
285+ f"end_pts should be larger than start_pts, got start_pts={ start_pts } and end_pts={ end_pts } "
286+ )
287+
288+ info = {}
289+ video_frames = []
290+ audio_frames = []
291+ audio_timebase = _video_opt .default_timebase
292+
293+ try :
294+ with av .open (filename , metadata_errors = "ignore" ) as container :
295+ if container .streams .audio :
296+ audio_timebase = container .streams .audio [0 ].time_base
297+ if container .streams .video :
298+ video_frames = _read_from_stream (
299+ container ,
300+ start_pts ,
301+ end_pts ,
302+ pts_unit ,
303+ container .streams .video [0 ],
304+ {"video" : 0 },
305+ )
306+ video_fps = container .streams .video [0 ].average_rate
307+ # guard against potentially corrupted files
308+ if video_fps is not None :
309+ info ["video_fps" ] = float (video_fps )
310+
311+ if container .streams .audio :
312+ audio_frames = _read_from_stream (
313+ container ,
314+ start_pts ,
315+ end_pts ,
316+ pts_unit ,
317+ container .streams .audio [0 ],
318+ {"audio" : 0 },
319+ )
320+ info ["audio_fps" ] = container .streams .audio [0 ].rate
321+
322+ except av .AVError :
323+ # TODO raise a warning?
324+ pass
325+
326+ vframes_list = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
327+ aframes_list = [frame .to_ndarray () for frame in audio_frames ]
328+
329+ if vframes_list :
330+ vframes = torch .as_tensor (np .stack (vframes_list ))
331+ else :
332+ vframes = torch .empty ((0 , 1 , 1 , 3 ), dtype = torch .uint8 )
333+
334+ if aframes_list :
335+ aframes = np .concatenate (aframes_list , 1 )
336+ aframes = torch .as_tensor (aframes )
337+ if pts_unit == "sec" :
338+ start_pts = int (math .floor (start_pts * (1 / audio_timebase )))
339+ if end_pts != float ("inf" ):
340+ end_pts = int (math .ceil (end_pts * (1 / audio_timebase )))
341+ aframes = _align_audio_frames (aframes , audio_frames , start_pts , end_pts )
342+ else :
343+ aframes = torch .empty ((1 , 0 ), dtype = torch .float32 )
342344
343345 if output_format == "TCHW" :
344346 # [T,H,W,C] --> [T,C,H,W]
0 commit comments