Skip to content

Commit d0c7803

Browse files
vertex-sdk-botcopybara-github
authored and committed
feat: Add new api for getting deploy options for Vertex AI Model Garden custom model
PiperOrigin-RevId: 796997464
1 parent 65bf9b6 commit d0c7803

File tree

2 files changed

+271
-8
lines changed

2 files changed

+271
-8
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,6 +1476,171 @@ def test_deploy_custom_model_with_all_config_success(self, deploy_mock):
14761476
)
14771477
)
14781478

1479+
def test_list_deploy_options_with_recommendations(self):
    """Tests list_deploy_options when recommend_spec returns recommendations."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    mock_model_service_client = mock.Mock()
    # Patch the client factory so CustomModel talks to our mock instead of
    # creating a real ModelServiceClient.
    with mock.patch.object(
        aiplatform.initializer.global_config,
        "create_client",
        return_value=mock_model_service_client,
    ):
        quota_state = types.RecommendSpecResponse.Recommendation.QuotaState
        # Three recommendations covering the three quota states: user has
        # quota, user has no quota, and UNSPECIFIED — the UNSPECIFIED state
        # must be omitted from the rendered output (see Option 3 below).
        mock_response = types.RecommendSpecResponse(
            recommendations=[
                types.RecommendSpecResponse.Recommendation(
                    spec=types.RecommendSpecResponse.MachineAndModelContainerSpec(
                        machine_spec=types.MachineSpec(
                            machine_type="n1-standard-4",
                            accelerator_type=types.AcceleratorType.NVIDIA_TESLA_T4,
                            accelerator_count=1,
                        )
                    ),
                    region="us-central1",
                    user_quota_state=quota_state.QUOTA_STATE_USER_HAS_QUOTA,
                ),
                types.RecommendSpecResponse.Recommendation(
                    spec=types.RecommendSpecResponse.MachineAndModelContainerSpec(
                        machine_spec=types.MachineSpec(
                            machine_type="n1-standard-8",
                            accelerator_type=types.AcceleratorType.NVIDIA_TESLA_V100,
                            accelerator_count=2,
                        )
                    ),
                    region="us-east1",
                    user_quota_state=quota_state.QUOTA_STATE_NO_USER_QUOTA,
                ),
                types.RecommendSpecResponse.Recommendation(
                    spec=types.RecommendSpecResponse.MachineAndModelContainerSpec(
                        machine_spec=types.MachineSpec(
                            machine_type="g2-standard-24",
                            accelerator_type=types.AcceleratorType.NVIDIA_L4,
                            accelerator_count=2,
                        )
                    ),
                    region="us-central1",
                    user_quota_state=quota_state.QUOTA_STATE_UNSPECIFIED,
                ),
            ]
        )
        mock_model_service_client.recommend_spec.return_value = mock_response

        custom_model = model_garden_preview.CustomModel(gcs_uri=_TEST_GCS_URI)
        result = custom_model.list_deploy_options()

        # NOTE(review): leading whitespace inside this dedent literal was
        # lost in extraction; value lines are assumed to carry one leading
        # space after dedent (matching the renderer's f' {k}="{v}"' form in
        # _model_garden.py) — confirm against the original file.
        expected_output = textwrap.dedent(
            """\
            [Option 1]
             machine_type="n1-standard-4",
             accelerator_type="NVIDIA_TESLA_T4",
             accelerator_count=1,
             region="us-central1",
             user_quota_state="QUOTA_STATE_USER_HAS_QUOTA"

            [Option 2]
             machine_type="n1-standard-8",
             accelerator_type="NVIDIA_TESLA_V100",
             accelerator_count=2,
             region="us-east1",
             user_quota_state="QUOTA_STATE_NO_USER_QUOTA"

            [Option 3]
             machine_type="g2-standard-24",
             accelerator_type="NVIDIA_L4",
             accelerator_count=2,
             region="us-central1\""""
        )
        assert result == expected_output
        # Default call checks machine availability with the 60s default
        # timeout (_DEFAULT_RECOMMEND_SPEC_TIMEOUT).
        mock_model_service_client.recommend_spec.assert_called_once_with(
            types.RecommendSpecRequest(
                gcs_uri=_TEST_GCS_URI,
                parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
                check_machine_availability=True,
            ),
            timeout=60,
        )
1565+
1566+
def test_list_deploy_options_with_specs(self):
    """Tests list_deploy_options when recommend_spec returns specs."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    mock_model_service_client = mock.Mock()
    # Patch the client factory so CustomModel talks to our mock instead of
    # creating a real ModelServiceClient.
    with mock.patch.object(
        aiplatform.initializer.global_config,
        "create_client",
        return_value=mock_model_service_client,
    ):
        # Response carries bare specs (no region / quota info), which is the
        # shape returned when availability checking is disabled.
        mock_response = types.RecommendSpecResponse(
            specs=[
                types.RecommendSpecResponse.MachineAndModelContainerSpec(
                    machine_spec=types.MachineSpec(
                        machine_type="n1-standard-4",
                        accelerator_type=types.AcceleratorType.NVIDIA_TESLA_T4,
                        accelerator_count=1,
                    )
                ),
                types.RecommendSpecResponse.MachineAndModelContainerSpec(
                    machine_spec=types.MachineSpec(
                        machine_type="n1-standard-8",
                        accelerator_type=types.AcceleratorType.NVIDIA_TESLA_V100,
                        accelerator_count=2,
                    )
                ),
            ]
        )
        mock_model_service_client.recommend_spec.return_value = mock_response

        custom_model = model_garden_preview.CustomModel(gcs_uri=_TEST_GCS_URI)
        result = custom_model.list_deploy_options(available_machines=False)

        # NOTE(review): leading whitespace inside this dedent literal was
        # lost in extraction; value lines are assumed to carry one leading
        # space after dedent — confirm against the original file.
        expected_output = textwrap.dedent(
            """\
            [Option 1]
             machine_type="n1-standard-4",
             accelerator_type="NVIDIA_TESLA_T4",
             accelerator_count=1

            [Option 2]
             machine_type="n1-standard-8",
             accelerator_type="NVIDIA_TESLA_V100",
             accelerator_count=2"""
        )
        assert result == expected_output
        # available_machines=False must propagate as
        # check_machine_availability=False on the request.
        mock_model_service_client.recommend_spec.assert_called_once_with(
            types.RecommendSpecRequest(
                gcs_uri=_TEST_GCS_URI,
                parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
                check_machine_availability=False,
            ),
            timeout=60,
        )
1622+
1623+
def test_list_deploy_options_exception(self):
    """Verifies that an error raised by recommend_spec propagates unchanged."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    service_client_mock = mock.Mock()
    factory_patch = mock.patch.object(
        aiplatform.initializer.global_config,
        "create_client",
        return_value=service_client_mock,
    )
    with factory_patch:
        # Make the RPC fail; list_deploy_options must surface the same error.
        service_client_mock.recommend_spec.side_effect = ValueError(
            "Test Error"
        )
        model_under_test = model_garden_preview.CustomModel(
            gcs_uri=_TEST_GCS_URI
        )
        with pytest.raises(ValueError) as raised:
            model_under_test.list_deploy_options()
        assert str(raised.value) == "Test Error"
        service_client_mock.recommend_spec.assert_called_once()
1643+
14791644

14801645
class TestModelGardenModel:
14811646
"""Test cases for Model Garden Model class."""

vertexai/model_garden/_model_garden.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from google.cloud.aiplatform import utils
3030
from google.cloud.aiplatform_v1beta1 import types
3131
from google.cloud.aiplatform_v1beta1.services import model_garden_service
32+
from google.cloud.aiplatform_v1beta1.services import model_service
3233
from vertexai import batch_prediction
3334

3435

@@ -38,6 +39,7 @@
3839
_LOGGER = base.Logger(__name__)
3940
_DEFAULT_VERSION = compat.V1BETA1
4041
_DEFAULT_TIMEOUT = 2 * 60 * 60 # 2 hours, same as UI one-click deployment.
42+
_DEFAULT_RECOMMEND_SPEC_TIMEOUT = 1 * 60 # 1 minute.
4143
_DEFAULT_EXPORT_TIMEOUT = 1 * 60 * 60 # 1 hour.
4244
_HF_WILDCARD_FILTER = "is_hf_wildcard(true)"
4345
_NATIVE_MODEL_FILTER = "is_hf_wildcard(false)"
@@ -256,6 +258,17 @@ class _ModelGardenClientWithOverride(utils.ClientWithOverride):
256258
)
257259

258260

261+
class _ModelServiceClientWithOverride(utils.ClientWithOverride):
    """Model service client wrapper supporting API-version overrides."""

    # NOTE(review): presumably marks the client as created per use rather
    # than cached — confirm against utils.ClientWithOverride.
    _is_temporary = True
    # Default to the module-wide API version (v1beta1).
    _default_version = _DEFAULT_VERSION
    # Maps each supported API version identifier to its client class.
    _version_map = (
        (
            _DEFAULT_VERSION,
            model_service.ModelServiceClient,
        ),
    )
271+
259272
class OpenModel:
260273
"""Represents a Model Garden Open model.
261274
@@ -860,6 +873,91 @@ def _model_garden_client(
860873
location_override=self._location,
861874
)
862875

876+
@functools.cached_property
def _model_service_client(
    self,
) -> model_service.ModelServiceClient:
    """Lazily creates and caches the Model Service client for this model."""
    client_kwargs = {
        "client_class": _ModelServiceClientWithOverride,
        "credentials": self._credentials,
        "location_override": self._location,
    }
    return initializer.global_config.create_client(**client_kwargs)
886+
887+
def list_deploy_options(
    self,
    available_machines: bool = True,
    request_timeout: Optional[float] = None,
) -> str:
    """Lists the deploy options for the model.

    Calls ``ModelService.RecommendSpec`` and renders each returned machine
    spec — plus region and user quota state when recommendations are
    returned — as a human-readable, numbered list.

    Args:
        available_machines: If true, only return the deploy options for
            available machines (sets ``check_machine_availability`` on the
            request).
        request_timeout: The timeout for the recommend spec request.
            Default is 60 seconds.

    Returns:
        str: A string of the deploy options represented by
            machine spec and container spec.

    Raises:
        Exception: Re-raises any error from the RecommendSpec RPC after
            logging it.
    """

    def _extract_spec(spec):
        # Flattens a machine spec into a display dict; absent fields become
        # None and are filtered out when the options are rendered below.
        machine_spec = spec.machine_spec
        return {
            "machine_type": getattr(machine_spec, "machine_type", None),
            "accelerator_type": getattr(
                getattr(machine_spec, "accelerator_type", None), "name", None
            ),
            "accelerator_count": getattr(machine_spec, "accelerator_count", None),
        }

    def _extract_recommendation(recommendation):
        # A recommendation additionally carries the region and, when it is
        # known (i.e. not UNSPECIFIED), the user's quota state.
        extracted_spec = _extract_spec(recommendation.spec)
        extracted_spec["region"] = getattr(recommendation, "region", None)
        if (
            recommendation.user_quota_state
            and recommendation.user_quota_state
            != types.RecommendSpecResponse.Recommendation.QuotaState.QUOTA_STATE_UNSPECIFIED
        ):
            extracted_spec["user_quota_state"] = getattr(
                getattr(recommendation, "user_quota_state", None), "name", None
            )
        return extracted_spec

    request = types.RecommendSpecRequest(
        gcs_uri=self._gcs_uri,
        parent=f"projects/{self._project}/locations/{self._location}",
        check_machine_availability=available_machines,
    )
    try:
        # NOTE: a request_timeout of 0 (falsy) falls back to the default.
        response = self._model_service_client.recommend_spec(
            request, timeout=request_timeout or _DEFAULT_RECOMMEND_SPEC_TIMEOUT
        )
        options = []
        if response.recommendations:
            options = [
                _extract_recommendation(recommendation)
                for recommendation in response.recommendations
                if recommendation.spec
            ]
        elif response.specs:
            options = [_extract_spec(spec) for spec in response.specs if spec]
        # Render each option as a numbered block; accelerator_count is the
        # only numeric (unquoted) field.
        return "\n\n".join(
            f"[Option {i + 1}]\n"
            + ",\n".join(
                f' {k}="{v}"' if k != "accelerator_count" else f" {k}={v}"
                for k, v in config.items()
                if v is not None
            )
            for i, config in enumerate(options)
        )

    except Exception as e:
        # Lazy %-args avoid building the message unless the log is emitted;
        # bare `raise` re-raises without adding this frame to the traceback
        # (the original code used `raise e` and an eager f-string).
        _LOGGER.error("Failed to list deploy options: %s", e)
        raise
960+
863961
def deploy(
864962
self,
865963
machine_type: Optional[str] = None,
@@ -906,14 +1004,14 @@ def deploy(
9061004
Created endpoint.
9071005
"""
9081006
return self._deploy_gcs_uri(
909-
machine_type,
910-
min_replica_count,
911-
max_replica_count,
912-
accelerator_type,
913-
accelerator_count,
914-
endpoint_display_name,
915-
model_display_name,
916-
deploy_request_timeout,
1007+
machine_type=machine_type,
1008+
min_replica_count=min_replica_count,
1009+
max_replica_count=max_replica_count,
1010+
accelerator_type=accelerator_type,
1011+
accelerator_count=accelerator_count,
1012+
endpoint_display_name=endpoint_display_name,
1013+
model_display_name=model_display_name,
1014+
deploy_request_timeout=deploy_request_timeout,
9171015
)
9181016

9191017
def _deploy_model_registry_model(self) -> aiplatform.Endpoint:

0 commit comments

Comments
 (0)