|
15 | 15 |
|
16 | 16 | import inspect |
17 | 17 | import math |
18 | | -from dataclasses import dataclass |
19 | 18 | from typing import Callable, Dict, List, Optional, Tuple, Union |
20 | 19 |
|
21 | 20 | import torch |
|
26 | 25 | from ...models.embeddings import get_3d_rotary_pos_embed |
27 | 26 | from ...pipelines.pipeline_utils import DiffusionPipeline |
28 | 27 | from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler |
29 | | -from ...utils import BaseOutput, logging, replace_example_docstring |
| 28 | +from ...utils import logging, replace_example_docstring |
30 | 29 | from ...utils.torch_utils import randn_tensor |
31 | 30 | from ...video_processor import VideoProcessor |
| 31 | +from .pipeline_output import CogVideoXPipelineOutput |
32 | 32 |
|
33 | 33 |
|
34 | 34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -136,21 +136,6 @@ def retrieve_timesteps( |
136 | 136 | return timesteps, num_inference_steps |
137 | 137 |
|
138 | 138 |
|
139 | | -@dataclass |
140 | | -class CogVideoXPipelineOutput(BaseOutput): |
141 | | - r""" |
142 | | - Output class for CogVideo pipelines. |
143 | | -
|
144 | | - Args: |
145 | | - frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): |
146 | | - List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing |
147 | | - denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape |
148 | | - `(batch_size, num_frames, channels, height, width)`. |
149 | | - """ |
150 | | - |
151 | | - frames: torch.Tensor |
152 | | - |
153 | | - |
154 | 139 | class CogVideoXPipeline(DiffusionPipeline): |
155 | 140 | r""" |
156 | 141 | Pipeline for text-to-video generation using CogVideoX. |
|
0 commit comments