Skip to content

Commit 5b2a48b

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Support Inline Source Deployment in Agent Engine
PiperOrigin-RevId: 817827304
1 parent 9d1cd6e commit 5b2a48b

File tree

4 files changed

+650
-61
lines changed

4 files changed

+650
-61
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# limitations under the License.
1414
#
1515
import asyncio
16+
import base64
1617
import importlib
18+
import io
1719
import json
1820
import logging
1921
import os
2022
import sys
23+
import tarfile
2124
import tempfile
2225
from typing import Any, AsyncIterable, Dict, Iterable, List
2326
from unittest import mock
@@ -901,6 +904,48 @@ def test_create_agent_engine_config_full(self, mock_prepare):
901904
== _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT
902905
)
903906

907+
@mock.patch.object(
908+
_agent_engines_utils,
909+
"_create_base64_encoded_tarball",
910+
return_value="test_tarball",
911+
)
912+
def test_create_agent_engine_config_with_source_packages(
913+
self, mock_create_base64_encoded_tarball
914+
):
915+
with tempfile.TemporaryDirectory() as tmpdir:
916+
test_file_path = os.path.join(tmpdir, "test_file.txt")
917+
with open(test_file_path, "w") as f:
918+
f.write("test content")
919+
requirements_file_path = os.path.join(tmpdir, "requirements.txt")
920+
with open(requirements_file_path, "w") as f:
921+
f.write("requests==2.0.0")
922+
923+
config = self.client.agent_engines._create_config(
924+
mode="create",
925+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
926+
description=_TEST_AGENT_ENGINE_DESCRIPTION,
927+
source_packages=[test_file_path],
928+
entrypoint_module="main",
929+
entrypoint_object="app",
930+
requirements_file=requirements_file_path,
931+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
932+
)
933+
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
934+
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
935+
assert config["spec"]["source_code_spec"] == {
936+
"inline_source": {"source_archive": "test_tarball"},
937+
"python_spec": {
938+
"version": _TEST_PYTHON_VERSION,
939+
"entrypoint_module": "main",
940+
"entrypoint_object": "app",
941+
"requirements_file": requirements_file_path,
942+
},
943+
}
944+
assert config["spec"]["class_methods"] == _TEST_AGENT_ENGINE_CLASS_METHODS
945+
mock_create_base64_encoded_tarball.assert_called_once_with(
946+
source_packages=[test_file_path]
947+
)
948+
904949
@mock.patch.object(_agent_engines_utils, "_prepare")
905950
def test_update_agent_engine_config_full(self, mock_prepare):
906951
config = self.client.agent_engines._create_config(
@@ -951,10 +996,10 @@ def test_update_agent_engine_config_full(self, mock_prepare):
951996
"spec.package_spec.pickle_object_gcs_uri",
952997
"spec.package_spec.dependency_files_gcs_uri",
953998
"spec.package_spec.requirements_gcs_uri",
999+
"spec.class_methods",
9541000
"spec.deployment_spec.env",
9551001
"spec.deployment_spec.secret_env",
9561002
"spec.service_account",
957-
"spec.class_methods",
9581003
"spec.agent_framework",
9591004
]
9601005
)
@@ -1170,6 +1215,19 @@ def test_to_parsed_json(self, obj, expected):
11701215
for got, want in zip(_agent_engines_utils._yield_parsed_json(obj), expected):
11711216
assert got == want
11721217

1218+
def test_create_base64_encoded_tarball(self):
1219+
with tempfile.TemporaryDirectory() as tmpdir:
1220+
test_file_path = os.path.join(tmpdir, "test_file.txt")
1221+
with open(test_file_path, "w") as f:
1222+
f.write("test content")
1223+
encoded_tarball = _agent_engines_utils._create_base64_encoded_tarball(
1224+
source_packages=[test_file_path]
1225+
)
1226+
decoded_tarball = base64.b64decode(encoded_tarball)
1227+
with tarfile.open(fileobj=io.BytesIO(decoded_tarball), mode="r:gz") as tar:
1228+
names = tar.getnames()
1229+
assert test_file_path.strip("/") in names
1230+
11731231

11741232
@pytest.mark.usefixtures("google_auth_mock")
11751233
class TestAgentEngine:
@@ -1365,6 +1423,10 @@ def test_create_agent_engine_with_env_vars_dict(
13651423
agent_server_mode=None,
13661424
labels=None,
13671425
class_methods=None,
1426+
source_packages=None,
1427+
entrypoint_module=None,
1428+
entrypoint_object=None,
1429+
requirements_file=None,
13681430
)
13691431
request_mock.assert_called_with(
13701432
"post",
@@ -1447,6 +1509,10 @@ def test_create_agent_engine_with_custom_service_account(
14471509
labels=None,
14481510
agent_server_mode=None,
14491511
class_methods=None,
1512+
source_packages=None,
1513+
entrypoint_module=None,
1514+
entrypoint_object=None,
1515+
requirements_file=None,
14501516
)
14511517
request_mock.assert_called_with(
14521518
"post",
@@ -1531,6 +1597,10 @@ def test_create_agent_engine_with_experimental_mode(
15311597
labels=None,
15321598
agent_server_mode=_genai_types.AgentServerMode.EXPERIMENTAL,
15331599
class_methods=None,
1600+
source_packages=None,
1601+
entrypoint_module=None,
1602+
entrypoint_object=None,
1603+
requirements_file=None,
15341604
)
15351605
request_mock.assert_called_with(
15361606
"post",
@@ -1553,6 +1623,72 @@ def test_create_agent_engine_with_experimental_mode(
15531623
None,
15541624
)
15551625

1626+
@mock.patch.object(
1627+
_agent_engines_utils,
1628+
"_create_base64_encoded_tarball",
1629+
return_value="test_tarball",
1630+
)
1631+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1632+
def test_create_agent_engine_with_source_packages(
1633+
self,
1634+
mock_await_operation,
1635+
mock_create_base64_encoded_tarball,
1636+
):
1637+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1638+
response=_genai_types.ReasoningEngine(
1639+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1640+
spec=_TEST_AGENT_ENGINE_SPEC,
1641+
)
1642+
)
1643+
with tempfile.TemporaryDirectory() as tmpdir:
1644+
test_file_path = os.path.join(tmpdir, "test_file.txt")
1645+
with open(test_file_path, "w") as f:
1646+
f.write("test content")
1647+
requirements_file_path = os.path.join(tmpdir, "requirements.txt")
1648+
with open(requirements_file_path, "w") as f:
1649+
f.write("requests==2.0.0")
1650+
1651+
with mock.patch.object(
1652+
self.client.agent_engines._api_client, "request"
1653+
) as request_mock:
1654+
request_mock.return_value = genai_types.HttpResponse(body="")
1655+
self.client.agent_engines.create(
1656+
config=_genai_types.AgentEngineConfig(
1657+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1658+
description=_TEST_AGENT_ENGINE_DESCRIPTION,
1659+
source_packages=[test_file_path],
1660+
entrypoint_module="main",
1661+
entrypoint_object="app",
1662+
requirements_file=requirements_file_path,
1663+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1664+
),
1665+
)
1666+
request_mock.assert_called_with(
1667+
"post",
1668+
"reasoningEngines",
1669+
{
1670+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1671+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1672+
"spec": {
1673+
"agent_framework": "custom",
1674+
"source_code_spec": {
1675+
"inline_source": {"source_archive": "test_tarball"},
1676+
"python_spec": {
1677+
"version": _TEST_PYTHON_VERSION,
1678+
"entrypoint_module": "main",
1679+
"entrypoint_object": "app",
1680+
"requirements_file": requirements_file_path,
1681+
},
1682+
},
1683+
"class_methods": _TEST_AGENT_ENGINE_CLASS_METHODS,
1684+
},
1685+
},
1686+
None,
1687+
)
1688+
mock_create_base64_encoded_tarball.assert_called_once_with(
1689+
source_packages=[test_file_path]
1690+
)
1691+
15561692
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
15571693
@mock.patch.object(_agent_engines_utils, "_await_operation")
15581694
def test_create_agent_engine_with_class_methods(
@@ -1613,6 +1749,10 @@ def test_create_agent_engine_with_class_methods(
16131749
labels=None,
16141750
agent_server_mode=None,
16151751
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1752+
source_packages=None,
1753+
entrypoint_module=None,
1754+
entrypoint_object=None,
1755+
requirements_file=None,
16161756
)
16171757
request_mock.assert_called_with(
16181758
"post",
@@ -1772,9 +1912,9 @@ def test_update_agent_engine_env_vars(
17721912
[
17731913
"spec.package_spec.pickle_object_gcs_uri",
17741914
"spec.package_spec.requirements_gcs_uri",
1915+
"spec.class_methods",
17751916
"spec.deployment_spec.env",
17761917
"spec.deployment_spec.secret_env",
1777-
"spec.class_methods",
17781918
"spec.agent_framework",
17791919
]
17801920
)

vertexai/_genai/_agent_engines_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import abc
1818
import asyncio
19+
import base64
1920
from importlib import metadata as importlib_metadata
2021
import inspect
2122
import io
@@ -1161,6 +1162,21 @@ def _upload_extra_packages(
11611162
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")
11621163

11631164

1165+
def _create_base64_encoded_tarball(
1166+
*,
1167+
source_packages: Sequence[str],
1168+
) -> str:
1169+
"""Creates a base64 encoded tarball from the source packages."""
1170+
logger.info("Creating in-memory tarfile of source_packages")
1171+
tar_fileobj = io.BytesIO()
1172+
with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar:
1173+
for file in source_packages:
1174+
tar.add(file)
1175+
tar_fileobj.seek(0)
1176+
tarball_bytes = tar_fileobj.read()
1177+
return base64.b64encode(tarball_bytes).decode("utf-8")
1178+
1179+
11641180
def _validate_extra_packages_or_raise(
11651181
*,
11661182
extra_packages: Sequence[str],

0 commit comments

Comments
 (0)