Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def unit(session: Session):
def lint(session: Session):
"""Enforce code style with flake8."""
session.install("-r", LINT_REQUIREMENTS)
session.run("flake8", "--config", ".flake8")
session.run(
"flake8", "--exclude", "venv,.venv,env,.nox,build,dist", "--config", ".flake8"
)
session.run("black", "--check", ".")
session.run("typos", "--config", "typos.toml", "-w")

Expand Down
6 changes: 6 additions & 0 deletions pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .experiment import ExperimentAPI
from .image import ImageAPI
from .job import JobAPI
from .lineage import LineageAPI
from .model import ModelAPI
from .pipeline import PipelineAPI
from .pipeline_run import PipelineRunAPI
Expand All @@ -49,6 +50,7 @@
PAIRestResourceTypes.PipelineRun: PipelineRunAPI,
PAIRestResourceTypes.TensorBoard: TensorBoardAPI,
PAIRestResourceTypes.Experiment: ExperimentAPI,
PAIRestResourceTypes.Lineage: LineageAPI,
}


Expand Down Expand Up @@ -217,3 +219,7 @@ def pipeline_run_api(self) -> PipelineRunAPI:
@property
def experiment_api(self) -> ExperimentAPI:
return self.get_api_by_resource(PAIRestResourceTypes.Experiment)

@property
def lineage_api(self) -> LineageAPI:
return self.get_api_by_resource(PAIRestResourceTypes.Lineage)
1 change: 1 addition & 0 deletions pai/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class PAIRestResourceTypes(object):
PipelineRun = "PipelineRun"
TensorBoard = "TensorBoard"
Experiment = "Experiment"
Lineage = "Lineage"


class ResourceAPI(with_metaclass(ABCMeta, object)):
Expand Down
76 changes: 76 additions & 0 deletions pai/api/lineage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional

from ..common.logging import get_logger
from ..libs.alibabacloud_aiworkspace20210204.models import (
LineageEntity,
RegisterLineageRequest,
)
from .base import ServiceName, WorkspaceScopedResourceAPI

logger = get_logger(__name__)


@dataclass
class _LineageEntity:
Attributes: Dict[str, str] = None
EntityType: Optional[str] = None
Name: Optional[str] = None
QualifiedName: Optional[str] = None


class LineageAPI(WorkspaceScopedResourceAPI):
BACKEND_SERVICE_NAME = ServiceName.PAI_WORKSPACE

_register_lineage = "register_lineage_with_options"

def log_lineage(
self,
inputs: List[_LineageEntity],
outputs: List[_LineageEntity],
job_id: str,
workspace_id: str,
):
input_entities = []
output_entities = []
for input in inputs:
input_entities.append(
LineageEntity(
attributes=input.Attributes,
entity_type=input.EntityType,
name=input.Name,
qualified_name=input.QualifiedName,
)
)
for output in outputs:
output_entities.append(
LineageEntity(
attributes=output.Attributes,
entity_type=output.EntityType,
name=output.Name,
qualified_name=output.QualifiedName,
)
)
request = RegisterLineageRequest(
register_task_as_entity=True,
input_entities=input_entities,
output_entities=output_entities,
qualified_name="pai_dlcjob-task." + job_id,
name=job_id,
attributes={"WorkspaceId": workspace_id},
)
response = self._do_request(method_=self._register_lineage, request=request)
logger.debug(response)
38 changes: 38 additions & 0 deletions pai/common/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,44 @@
# PAI VPC endpoint
PAI_VPC_ENDPOINT = "pai-vpc.{}.aliyuncs.com"

# All region list, https://help.aliyun.com/document_detail/40654.html
ALIYUN_ALL_REGION_ID_LIST = [
"cn-qingdao",
"cn-beijing",
"cn-zhangjiakou",
"cn-huhehaote",
"cn-wulanchabu",
"cn-hangzhou",
"cn-shanghai",
"cn-nanjing",
"cn-fuzhou",
"cn-wuhan-lr",
"cn-shenzhen",
"cn-heyuan",
"cn-guangzhou",
"cn-chengdu",
"cn-hongkong",
"ap-southeast-1",
"ap-southeast-2",
"ap-southeast-3",
"ap-southeast-5",
"ap-southeast-6",
"ap-southeast-7",
"ap-northeast-1",
"ap-northeast-2",
"eu-west-1",
"us-east-1",
"eu-central-1",
"eu-west-1",
"me-east-1",
"me-central-1",
"cn-hangzhou-finance",
"cn-shanghai-finance-1",
"cn-shenzhen-finance-1",
"cn-beijing-finance-1",
"cn-north-2-gov-1",
]


class Network(enum.Enum):
VPC = "VPC"
Expand Down
110 changes: 110 additions & 0 deletions pai/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from semantic_version import Version

from pai.common.consts import (
ALIYUN_ALL_REGION_ID_LIST,
INSTANCE_TYPE_LOCAL,
INSTANCE_TYPE_LOCAL_GPU,
FileSystemInputScheme,
Expand Down Expand Up @@ -379,3 +380,112 @@ def name_from_base(base_name: str, sep: str = "-") -> str:
return "{base_name}{sep}{timestamp}".format(
base_name=base_name, sep=sep, timestamp=timestamp(sep=sep, utc=False)
)


def parse_region_id_from_endpoint(endpoint) -> str:
if endpoint:
for region_id in ALIYUN_ALL_REGION_ID_LIST:
if region_id in endpoint:
return region_id
return None


def parse_oss_uri(uri):
if uri.startswith("oss://"):
match = re.match(r"^oss://([^./]+)\.([^./]+)\.aliyuncs\.com(?:/(.+))?", uri)
if not match:
warnings.warn("Invalid OSS URI format.")
return None
bucket_name, endpoint, path = match.groups()
region_id = parse_region_id_from_endpoint(endpoint)
if not region_id:
warnings.warn("Invalid OSS URI format.")
return None
return bucket_name, region_id, "/" if path is None else path
return None


def parse_nas_uri(uri):
if uri.startswith("nas://"):
match = re.match(r"^nas://([^./]+)\.([^/]+)(?:/(.+))?", uri)
if not match:
warnings.warn("Invalid NAS URI format.")
return None
endpoint = match.groups()[1]
region_id = parse_region_id_from_endpoint(endpoint)
if not region_id:
warnings.warn("Invalid NAS URI format.")
return None
return uri, region_id
return None


def parse_cpfs_uri(uri):
if uri.startswith("cpfs://"):
match = re.match(r"^cpfs://([^./]+)\.([^/]+)(?:/(.+))?", uri)
if not match:
warnings.warn("Invalid CPFS URI format.")
return None
endpoint = match.groups()[1]
region_id = parse_region_id_from_endpoint(endpoint)
if not region_id:
warnings.warn("Invalid CPFS URI format.")
return None
return uri, region_id
return None


def parse_bmcpfs_uri(uri):
if uri.startswith("bmcpfs://"):
match = re.match(r"^bmcpfs://([^./]+)\.([^/]+)(?:/(.+))?", uri)
if not match:
warnings.warn("Invalid BMCPFS URI format.")
return None
endpoint = match.groups()[1]
region_id = parse_region_id_from_endpoint(endpoint)
if not region_id:
warnings.warn("Invalid BMCPFS URI format.")
return None
return uri, region_id
return None


def parse_local_file_uri(uri):
if uri.startswith("file:///"):
match = re.match(r"^file://(.+)", uri)
if not match:
warnings.warn("Invalid local file URI format.")
return None
return match.group(1)
return None


def parse_pai_dataset_uri(uri):
if uri.startswith("pai://datasets"):
match = re.match(r"^pai://datasets/([^/]+)(?:/(.+))?", uri)
if not match:
warnings.warn("Invalid PAI dataset URI format.")
return None
dataset_id, dataset_version = match.groups()
dataset_version = dataset_version if dataset_version else "1"
dataset_version = (
dataset_version.split("/")[0] if "/" in dataset_version else dataset_version
)
return dataset_id, dataset_version
return None


def parse_odps_uri(uri):
if uri.startswith("odps://"):
match = re.match(r"^odps://(.+)/tables/(.+)", uri)
if not match:
warnings.warn("Invalid MaxCompute URI format.")
return None
project_and_schema, table_name = match.groups()
project_name, schema = (
project_and_schema.split("/")
if "/" in project_and_schema
else (project_and_schema, None)
)
return project_name, schema, table_name
return None
2 changes: 1 addition & 1 deletion pai/libs/alibabacloud_aiworkspace20210204/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '3.0.6'
__version__ = '5.0.1'
Loading
Loading