Commit 7a3d3b5

Author: Aryan

Added first/last frame support for Veo on Vertex

1 parent f472fc5 · 2 files changed · 152 additions & 98 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [project]
 name = "externalapi-helpers"
 description = "Various ComfyUI nodes for Gemini, Replicate and OpenAI"
-version = "1.0.0"
+version = "1.0.1"
 license = {file = "LICENSE"}
 # classifiers = [
 #   # For OS-independent nodes (works on all operating systems)

veo.py

Lines changed: 151 additions & 97 deletions

@@ -1,26 +1,27 @@
 import time
 import os
+import io
+import json
+import tempfile
 import torch
 import numpy as np
 from PIL import Image
-import tempfile
 import uuid
+from typing import Optional
 from google import genai
-from google.genai.types import GenerateVideosConfig
+from google.genai import types
 import cv2
 
-class Veo3VideoGenerator:
-
-    def __init__(self):
-        pass
-
+
+class VeoVertexVideoGenerator:
     @classmethod
     def INPUT_TYPES(cls):
         return {
             "required": {
+                "prompt": ("STRING", {"multiline": True, "default": "a cat reading a book"}),
                 "project_id": ("STRING", {"multiline": False, "default": ""}),
                 "location": ([
-                    "us-central1", "us-east1", "us-east4", "us-east5", "us-south1",
+                    "global", "us-central1", "us-east1", "us-east4", "us-east5", "us-south1",
                     "us-west1", "us-west2", "us-west3", "us-west4",
                     "northamerica-northeast1", "northamerica-northeast2",
                     "southamerica-east1", "southamerica-west1", "africa-south1",
@@ -32,62 +33,73 @@ def INPUT_TYPES(cls):
                     "asia-southeast2", "australia-southeast1", "australia-southeast2",
                     "me-central1", "me-central2", "me-west1"
                 ], {"default": "us-central1"}),
-                "service_account": ("STRING", {"multiline": False, "default": ""}),
-                "prompt": ("STRING", {"multiline": True, "default": "a cat reading a book"}),
+                "service_account": ("STRING", {"multiline": True, "default": ""}),
+                "model": ([
+                    "veo-2.0-generate-001",
+                    "veo-2.0-generate-exp",
+                    "veo-2.0-generate-preview",
+                    "veo-3.0-generate-001",
+                    "veo-3.0-fast-generate-001",
+                    "veo-3.1-generate-001",
+                    "veo-3.1-fast-generate-001"
+                ], {"default": "veo-3.0-generate-001"}),
+                "resolution": (["720p", "1080p"], {"default": "720p"}),
+                "aspect_ratio": (["16:9", "9:16"], {"default": "16:9"}),
+                "duration_seconds": ("INT", {"default": 4, "min": 4, "max": 8, "step": 1}),
+                "seed": ("INT", {"default": 69, "min": 1, "max": 2147483646, "step": 1}),
+            },
+            "optional": {
                 "negative_prompt": ("STRING", {"multiline": True, "default": ""}),
-                "model": (["veo-3.0-generate-preview", "veo-3.0-fast-generate-preview","veo-3.0-generate-001", "veo-2.0-generate-001"], {"default": "veo-3.0-generate-001"}),
-                "aspect_ratio": (["16:9"], {"default": "16:9"}),
-                "generate_audio": ("BOOLEAN", {"default": False}),
-                "seed": ("INT", {"default": -1, "min": -1, "max": 0xffffffffffffffff}),
+                "first_frame": ("IMAGE",),
+                "last_frame": ("IMAGE",),
             }
         }
 
     RETURN_TYPES = ("IMAGE",)
     RETURN_NAMES = ("frames",)
     FUNCTION = "generate_video"
     CATEGORY = "video/generation"
+    OUTPUT_IS_LIST = (True,)
 
-    def setup_client(self, service_account_path, project_id, location):
-        if service_account_path.strip():
-            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = service_account_path.strip()
-
-        if not os.environ.get('GOOGLE_APPLICATION_CREDENTIALS'):
-            raise ValueError("Service account path is required.")
+    def setup_client(self, service_account_json, project_id, location):
+        """Setup Vertex AI client with service account JSON content"""
+        if not service_account_json.strip():
+            raise ValueError("Service account JSON content is required.")
 
         if not project_id.strip():
             raise ValueError("Project ID is required.")
 
-        return genai.Client(vertexai=True, project=project_id.strip(), location=location.strip())
+        # Validate and write JSON content to temporary file
+        try:
+            json.loads(service_account_json)  # Validate JSON format
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON content: {str(e)}")
+
+        # Create temporary file with JSON content
+        temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
+        temp_file.write(service_account_json.strip())
+        temp_file.close()
+
+        # Set credentials path
+        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = temp_file.name
+
+        return genai.Client(
+            vertexai=True,
+            project=project_id.strip(),
+            location=location.strip()
+        )
 
     def pil_to_tensor(self, pil_image):
         if pil_image.mode != 'RGB':
             pil_image = pil_image.convert('RGB')
 
         numpy_image = np.array(pil_image).astype(np.float32) / 255.0
         return torch.from_numpy(numpy_image).unsqueeze(0)
-
-    def video_to_frames(self, video_response):
+
+    def video_to_frames(self, video_bytes):
         temp_video_path = os.path.join(tempfile.gettempdir(), f"temp_video_{uuid.uuid4().hex}.mp4")
 
         try:
-            video_bytes = None
-
-            if hasattr(video_response, 'video_bytes'):
-                video_bytes = video_response.video_bytes
-            elif hasattr(video_response, 'data'):
-                video_bytes = video_response.data
-            elif hasattr(video_response.video, 'data'):
-                video_bytes = video_response.video.data
-            elif hasattr(video_response.video, 'video_bytes'):
-                video_bytes = video_response.video.video_bytes
-            elif hasattr(video_response.video, 'bytes'):
-                video_bytes = video_response.video.bytes
-            else:
-                raise ValueError("Could not find video bytes in response")
-
-            if video_bytes is None:
-                raise ValueError("Video bytes are None")
-
             with open(temp_video_path, 'wb') as f:
                 f.write(video_bytes)
@@ -100,8 +112,7 @@ def video_to_frames(self, video_response):
                     break
 
                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                pil_frame = Image.fromarray(frame_rgb)
-                tensor_frame = self.pil_to_tensor(pil_frame)
+                tensor_frame = self.pil_to_tensor(Image.fromarray(frame_rgb))
                 frames.append(tensor_frame)
 
             cap.release()
@@ -115,64 +126,107 @@ def video_to_frames(self, video_response):
             if os.path.exists(temp_video_path):
                 os.remove(temp_video_path)
 
-    def generate_video(self, project_id, location, service_account, prompt, negative_prompt, model,
-                       aspect_ratio, generate_audio, seed):
+    def tensor_to_image_bytes(self, image_tensor):
+        """Convert ComfyUI image tensor to bytes"""
+        img_array = image_tensor.cpu().numpy() if isinstance(image_tensor, torch.Tensor) else image_tensor
 
-        try:
-            client = self.setup_client(service_account, project_id, location)
-
-            config_params = {
-                "aspect_ratio": aspect_ratio,
-                "generate_audio": generate_audio
-            }
-
-            if seed != -1:
-                config_params["seed"] = seed
-            if negative_prompt and negative_prompt.strip():
-                config_params["negative_prompt"] = negative_prompt.strip()
+        if len(img_array.shape) == 4:
+            img_array = img_array[0]
+
+        if img_array.dtype in [np.float32, np.float64]:
+            img_array = (img_array * 255).astype(np.uint8)
+
+        buffered = io.BytesIO()
+        Image.fromarray(img_array).save(buffered, format="PNG")
+        return buffered.getvalue()
+
+    def generate_video(self, prompt: str, project_id: str, location: str,
+                       service_account: str, model: str, resolution: str, aspect_ratio: str,
+                       duration_seconds: int, seed: int,
+                       negative_prompt: Optional[str] = None,
+                       first_frame: Optional[torch.Tensor] = None,
+                       last_frame: Optional[torch.Tensor] = None):
+
+        # Initialize Vertex AI client
+        client = self.setup_client(service_account, project_id, location)
+
+        # Configure video generation
+        config_params = {
+            "resolution": resolution,
+            "aspect_ratio": aspect_ratio,
+            "duration_seconds": duration_seconds,
+        }
+
+        if seed != -1:
+            config_params["seed"] = seed
+
+        if negative_prompt and negative_prompt.strip():
+            config_params["negative_prompt"] = negative_prompt.strip()
+
+        video_config = types.GenerateVideosConfig(**config_params)
+
+        # Prepare generation parameters
+        generation_params = {
+            "model": model,
+            "prompt": prompt,
+            "config": video_config,
+        }
+
+        # Handle first frame image
+        if first_frame is not None:
+            image_bytes = self.tensor_to_image_bytes(first_frame)
+            generation_params["image"] = types.Image(
+                image_bytes=image_bytes,
+                mime_type="image/png"
+            )
+            print("First frame image provided for video generation")
+
+        # Handle last frame image (dynamic attribute for preview SDK)
+        if last_frame is not None:
+            last_image_bytes = self.tensor_to_image_bytes(last_frame)
+            last_frame_img = types.Image(
+                image_bytes=last_image_bytes,
+                mime_type="image/png"
+            )
+            setattr(video_config, 'last_frame', last_frame_img)
+            print("Last frame image provided for video generation")
+
+        print(f"Starting video generation with model {model}...")
+        operation = client.models.generate_videos(**generation_params)
+        print(f"Operation started: {operation.name}")
+
+        # Poll for completion
+        print("Waiting for video generation to complete...")
+        while not operation.done:
+            time.sleep(10)
+            operation = client.operations.get(operation)
+            print(".", end="", flush=True)
+        print("")
+
+        # Check for errors
+        if operation.error:
+            raise Exception(f"Operation failed: {operation.error}")
+
+        # Retrieve video
+        if not operation.result or not operation.result.generated_videos:
+            raise Exception("No videos were generated.")
+
+        video_result = operation.result.generated_videos[0].video
+
+        if not video_result.video_bytes:
+            raise Exception("No video bytes returned from API")
+
+        print("Video generated successfully. Extracting frames...")
+        frames_tensor = self.video_to_frames(video_result.video_bytes)
+        print(f"Extracted {frames_tensor.shape[0]} frames.")
+
+        return ([frames_tensor],)
 
-            config = GenerateVideosConfig(**config_params)
-
-            generation_params = {
-                "model": model,
-                "prompt": prompt,
-                "config": config
-            }
-
-            operation = client.models.generate_videos(**generation_params)
-            print(f"Operation started: {operation.name}")
-
-            poll_interval = 15
-            timeout_minutes = 10
-            timeout_seconds = timeout_minutes * 60
-
-            start_time = time.time()
-            while not operation.done:
-                if time.time() - start_time > timeout_seconds:
-                    raise TimeoutError(f"Video generation timed out after {timeout_minutes} minutes.")
-
-                time.sleep(poll_interval)
-                operation = client.operations.get(operation)
-
-            if operation.response:
-                generated_video = operation.response.generated_videos[0]
-                frames_tensor = self.video_to_frames(generated_video)
-                print(f"Video generated successfully. Extracted {frames_tensor.shape[0]} frames.")
-                return (frames_tensor,)
-            else:
-                error_msg = f"Generation failed: {operation.error}" if operation.error else "Unknown error."
-                raise Exception(error_msg)
-
-        except Exception as e:
-            error_msg = f"Error in video generation process: {str(e)}"
-            print(error_msg)
-            empty_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32)
-            return (empty_tensor,)
 
 NODE_CLASS_MAPPINGS = {
-    "Veo3VideoGenerator": Veo3VideoGenerator
+    "VeoVertexVideoGenerator": VeoVertexVideoGenerator
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "Veo3VideoGenerator": "Veo Text-to-Video (Vertex AI)"
-}
+    "VeoVertexVideoGenerator": "Veo (Vertex AI)"
+}
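
For context outside the diff: below is a minimal, standalone sketch of the first/last-frame flow this commit wires into ComfyUI, driven directly through the google-genai SDK. It is an illustration under assumptions, not part of the commit: it presumes an SDK version where GenerateVideosConfig accepts last_frame as a regular field (the node instead attaches it via setattr for preview SDKs), and "my-project", the model choice, and first.png/last.png are placeholders.

    import time
    from google import genai
    from google.genai import types

    client = genai.Client(vertexai=True, project="my-project", location="us-central1")

    def png_bytes(path):
        # Helper for this sketch only; the node converts ComfyUI tensors instead.
        with open(path, "rb") as f:
            return f.read()

    config = types.GenerateVideosConfig(
        aspect_ratio="16:9",
        resolution="720p",
        duration_seconds=8,
        # Target frame to interpolate toward (assumed field; the node uses setattr).
        last_frame=types.Image(image_bytes=png_bytes("last.png"), mime_type="image/png"),
    )

    operation = client.models.generate_videos(
        model="veo-3.1-generate-001",  # placeholder; pick a model that supports interpolation
        prompt="a cat reading a book",
        image=types.Image(image_bytes=png_bytes("first.png"), mime_type="image/png"),
        config=config,
    )

    # Same polling pattern as the node's generate_video.
    while not operation.done:
        time.sleep(10)
        operation = client.operations.get(operation)

    with open("out.mp4", "wb") as f:
        f.write(operation.result.generated_videos[0].video.video_bytes)

Credentials are assumed to be available via GOOGLE_APPLICATION_CREDENTIALS, which is what the node's setup_client arranges by writing the pasted service-account JSON to a temporary file.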
