diff --git a/.stats.yml b/.stats.yml index a3e1592f..edc1c05a 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 55 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-51627d25c5c4ea3cf03c92a335acf66cf8cad652079915109fe9711a57f7e003.yml -openapi_spec_hash: 97f97a89965aa05900566ca2824a4de1 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-28d95f054f8ffe846f42a014a1c86d0385604e9d05850c17819ef986a068ac88.yml +openapi_spec_hash: 83d5ac256007e9ccd40abe11a5983168 config_hash: 6acd26f13abe2b4550fb4bbb06d31523 diff --git a/examples/tokenize_data.py b/examples/tokenize_data.py new file mode 100644 index 00000000..1b9d0035 --- /dev/null +++ b/examples/tokenize_data.py @@ -0,0 +1,222 @@ +import logging +import argparse +from typing import Dict, List +from functools import partial +from multiprocessing import cpu_count + +from datasets import Dataset, load_dataset # type: ignore +from transformers import ( # type: ignore + AutoTokenizer, + BatchEncoding, + PreTrainedTokenizerBase, +) + +# see default of ignore_index +# for https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss +LOSS_IGNORE_INDEX = -100 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def tokenize_variable_length( + data: Dict[str, str], + tokenizer: PreTrainedTokenizerBase, + add_special_tokens: bool = True, +) -> BatchEncoding: + tokenized = tokenizer(data["text"], add_special_tokens=add_special_tokens, truncation=False) + return tokenized + + +def tokenize_constant_length( + data: Dict[str, str], + tokenizer: PreTrainedTokenizerBase, + max_length: int = 2048, + add_special_tokens: bool = True, + add_labels: bool = True, +) -> BatchEncoding: + # tokenized contains `input_ids` and `attention_mask` + tokenized: BatchEncoding = tokenizer( + data["text"], + max_length=max_length, + truncation=True, + 
def pack_sequences(
    batch: "BatchEncoding",
    max_seq_len: int,
    pad_token_id: int,
    eos_token_id: int,
    add_labels: bool,
    cutoff_size: int = 0,
) -> Dict[str, List[List[int]]]:
    """
    Pack tokenized sequences into fixed-length rows of ``max_seq_len`` tokens.

    Every sequence in ``batch["input_ids"]`` is terminated with ``eos_token_id``
    and the results are concatenated into a single token stream, which is then
    cut into consecutive rows of exactly ``max_seq_len`` tokens. A sequence may
    therefore be split across two rows.

    Args:
        batch: mapping with an ``"input_ids"`` entry holding one list of token
            ids per example.
        max_seq_len: length of every returned row; must be positive.
        pad_token_id: id used to right-pad the final, partially filled row.
        eos_token_id: id appended after each input sequence.
        add_labels: when true, also return ``labels`` (a copy of ``input_ids``
            with padding positions replaced by ``LOSS_IGNORE_INDEX``).
        cutoff_size: the final partial row is kept (and padded) only if it
            holds more than ``cutoff_size`` tokens. ``cutoff_size = max_seq_len``
            therefore drops any non-full row (full packing without padding).

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and, if requested,
        ``labels``; each value is a list of rows of length ``max_seq_len``.

    Raises:
        ValueError: if ``max_seq_len`` is not positive.

    Example (``max_seq_len=4``, ``eos_token_id=0``, ``pad_token_id=0``)::

        input_ids       [[1, 2, 3], [4, 5]]
        packed rows     [[1, 2, 3, 0], [4, 5, 0, 0]]
        attention_mask  [[1, 1, 1, 1], [1, 1, 1, 0]]
    """
    if max_seq_len <= 0:
        # A non-positive length could never be filled and would previously spin forever.
        raise ValueError("`max_seq_len` must be a positive integer")

    # Concatenate all sequences, each followed by EOS, into one token stream.
    stream: List[int] = []
    for input_ids in batch["input_ids"]:
        stream.extend(input_ids)
        stream.append(eos_token_id)

    # Cut the stream into full rows in a single pass.
    full_end = len(stream) - len(stream) % max_seq_len
    packed_sequences = [stream[start : start + max_seq_len] for start in range(0, full_end, max_seq_len)]
    # Number of padding tokens at the end of each row; only the last row can have any.
    pad_counts = [0] * len(packed_sequences)

    # Keep the leftover tokens only if there are more of them than the cutoff.
    tail = stream[full_end:]
    if tail and len(tail) > cutoff_size:
        n_pad = max_seq_len - len(tail)
        packed_sequences.append(tail + [pad_token_id] * n_pad)
        pad_counts.append(n_pad)

    output = {"input_ids": packed_sequences}

    # Padding is identified by position rather than by token id: the pad token
    # may be the same id as EOS, and real EOS separators must stay visible to
    # both the loss and the attention mask.
    if add_labels:
        output["labels"] = [
            example[: max_seq_len - n_pad] + [LOSS_IGNORE_INDEX] * n_pad
            for example, n_pad in zip(packed_sequences, pad_counts)
        ]

    # mask attention for padding tokens, a better version would also mask cross-sequence dependencies
    output["attention_mask"] = [[1] * (max_seq_len - n_pad) + [0] * n_pad for n_pad in pad_counts]
    return output
add_special_tokens=True, + ) + + assert "input_ids" in tokenized_data.column_names and "attention_mask" in tokenized_data.column_names + + if args.add_labels: + assert "labels" in tokenized_data.column_names + + logger.info("Tokenized data:") + print(tokenized_data) + + logger.info(f"Saving data to {args.out_filename}") + print(len(tokenized_data[0]["input_ids"])) + tokenized_data.to_parquet(args.out_filename) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Pretokenize examples for finetuning via Together") + parser.add_argument( + "--dataset", + type=str, + default="clam004/antihallucination_dataset", + help="Dataset name on the Hugging Face Hub", + ) + parser.add_argument("--max-seq-length", type=int, default=8192, help="Maximum sequence length") + parser.add_argument( + "--add-labels", + action="store_true", + help="Whether to add loss labels from padding tokens", + ) + parser.add_argument( + "--tokenizer", + type=str, + required=True, + help="Tokenizer name (for example, togethercomputer/Llama-3-8b-hf)", + ) + parser.add_argument( + "--out-filename", + default="processed_dataset.parquet", + help="Name of the Parquet file to save (should have .parquet extension)", + ) + parser.add_argument( + "--packing", + action="store_true", + help="Whether to pack shorter sequences up to `--max-seq-length`", + ) + args = parser.parse_args() + + process_data(args) diff --git a/pyproject.toml b/pyproject.toml index c90ff583..2cbfb713 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,14 +15,15 @@ dependencies = [ "anyio>=3.5.0, <5", "distro>=1.7.0, <2", "sniffio", - "click>=8.1.7", - "rich>=13.7.1", - "tabulate>=0.9.0", - "pillow>=10.4.0", - "types-tabulate>=0.9.0.20240106", - "tqdm>=4.67.1", - "types-tqdm>=4.67.0.20250516", - "filelock>=3.13.1", + "click>=8.1.7", + "rich>=13.7.1", + "tabulate>=0.9.0", + "pillow>=10.4.0", + "types-tabulate>=0.9.0.20240106", + "tqdm>=4.67.1", + "types-tqdm>=4.67.0.20250516", + "filelock>=3.13.1", + 
"py-machineid>=1.0.0", ] requires-python = ">= 3.9" @@ -150,6 +151,7 @@ exclude = [ ".venv", ".nox", ".git", + "examples", ] reportImplicitOverride = true diff --git a/src/together/_base_client.py b/src/together/_base_client.py index 3a631d34..9838c526 100644 --- a/src/together/_base_client.py +++ b/src/together/_base_client.py @@ -9,6 +9,7 @@ import inspect import logging import platform +import warnings import email.utils from types import TracebackType from random import random @@ -51,9 +52,11 @@ ResponseT, AnyMapping, PostParser, + BinaryTypes, RequestFiles, HttpxSendArgs, RequestOptions, + AsyncBinaryTypes, HttpxRequestFiles, ModelBuilderProtocol, not_given, @@ -477,8 +480,19 @@ def _build_request( retries_taken: int = 0, ) -> httpx.Request: if log.isEnabledFor(logging.DEBUG): - log.debug("Request options: %s", model_dump(options, exclude_unset=True)) - + log.debug( + "Request options: %s", + model_dump( + options, + exclude_unset=True, + # Pydantic v1 can't dump every type we support in content, so we exclude it for now. 
+ exclude={ + "content", + } + if PYDANTIC_V1 + else {}, + ), + ) kwargs: dict[str, Any] = {} json_data = options.json_data @@ -532,7 +546,13 @@ def _build_request( is_body_allowed = options.method.lower() != "get" if is_body_allowed: - if isinstance(json_data, bytes): + if options.content is not None and json_data is not None: + raise TypeError("Passing both `content` and `json_data` is not supported") + if options.content is not None and files is not None: + raise TypeError("Passing both `content` and `files` is not supported") + if options.content is not None: + kwargs["content"] = options.content + elif isinstance(json_data, bytes): kwargs["content"] = json_data else: kwargs["json"] = json_data if is_given(json_data) else None @@ -1194,6 +1214,7 @@ def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[False] = False, @@ -1206,6 +1227,7 @@ def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[True], @@ -1219,6 +1241,7 @@ def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool, @@ -1231,13 +1254,25 @@ def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and 
will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="post", url=path, json_data=body, files=to_httpx_files(files), **options + method="post", url=path, json_data=body, content=content, files=to_httpx_files(files), **options ) return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) @@ -1247,11 +1282,23 @@ def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="patch", url=path, json_data=body, files=to_httpx_files(files), **options + method="patch", url=path, json_data=body, content=content, files=to_httpx_files(files), **options ) return self.request(cast_to, opts) @@ -1261,11 +1308,23 @@ def put( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. 
" + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="put", url=path, json_data=body, files=to_httpx_files(files), **options + method="put", url=path, json_data=body, content=content, files=to_httpx_files(files), **options ) return self.request(cast_to, opts) @@ -1275,9 +1334,19 @@ def delete( *, cast_to: Type[ResponseT], body: Body | None = None, + content: BinaryTypes | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options) + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options) return self.request(cast_to, opts) def get_api_list( @@ -1717,6 +1786,7 @@ async def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[False] = False, @@ -1729,6 +1799,7 @@ async def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[True], @@ -1742,6 +1813,7 @@ async def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, stream: bool, @@ -1754,13 +1826,25 @@ async def post( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, 
options: RequestOptions = {}, stream: bool = False, stream_cls: type[_AsyncStreamT] | None = None, ) -> ResponseT | _AsyncStreamT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options + method="post", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options ) return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) @@ -1770,11 +1854,28 @@ async def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. 
" + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options + method="patch", + url=path, + json_data=body, + content=content, + files=await async_to_httpx_files(files), + **options, ) return await self.request(cast_to, opts) @@ -1784,11 +1885,23 @@ async def put( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) opts = FinalRequestOptions.construct( - method="put", url=path, json_data=body, files=await async_to_httpx_files(files), **options + method="put", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options ) return await self.request(cast_to, opts) @@ -1798,9 +1911,19 @@ async def delete( *, cast_to: Type[ResponseT], body: Body | None = None, + content: AsyncBinaryTypes | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options) + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. 
" + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options) return await self.request(cast_to, opts) def get_api_list( diff --git a/src/together/_client.py b/src/together/_client.py index 452d9387..aa85275c 100644 --- a/src/together/_client.py +++ b/src/together/_client.py @@ -3,11 +3,14 @@ from __future__ import annotations import os +import sys from typing import TYPE_CHECKING, Any, Mapping from typing_extensions import Self, override import httpx +from together.lib._google_colab import get_google_colab_secret + from . import _exceptions from ._qs import Querystring from ._types import ( @@ -113,6 +116,8 @@ def __init__( """ if api_key is None: api_key = os.environ.get("TOGETHER_API_KEY") + if api_key is None and "google.colab" in sys.modules: + api_key = get_google_colab_secret("TOGETHER_API_KEY") if api_key is None: raise TogetherError( "The api_key client option must be set either by passing api_key to the client or by setting the TOGETHER_API_KEY environment variable" @@ -388,6 +393,8 @@ def __init__( """ if api_key is None: api_key = os.environ.get("TOGETHER_API_KEY") + if api_key is None and "google.colab" in sys.modules: + api_key = get_google_colab_secret("TOGETHER_API_KEY") if api_key is None: raise TogetherError( "The api_key client option must be set either by passing api_key to the client or by setting the TOGETHER_API_KEY environment variable" diff --git a/src/together/_models.py b/src/together/_models.py index ca9500b2..29070e05 100644 --- a/src/together/_models.py +++ b/src/together/_models.py @@ -3,7 +3,20 @@ import os import inspect import weakref -from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast +from typing import ( + IO, + TYPE_CHECKING, + Any, + Type, + Union, + Generic, + TypeVar, + Callable, + Iterable, + Optional, + AsyncIterable, + cast, 
+) from datetime import date, datetime from typing_extensions import ( List, @@ -787,6 +800,7 @@ class FinalRequestOptionsInput(TypedDict, total=False): timeout: float | Timeout | None files: HttpxRequestFiles | None idempotency_key: str + content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] json_data: Body extra_json: AnyMapping follow_redirects: bool @@ -805,6 +819,7 @@ class FinalRequestOptions(pydantic.BaseModel): post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() follow_redirects: Union[bool, None] = None + content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None # It should be noted that we cannot use `json` here as that would override # a BaseModel method in an incompatible fashion. json_data: Union[Body, None] = None diff --git a/src/together/_types.py b/src/together/_types.py index a39b8518..cf3a156f 100644 --- a/src/together/_types.py +++ b/src/together/_types.py @@ -13,9 +13,11 @@ Mapping, TypeVar, Callable, + Iterable, Iterator, Optional, Sequence, + AsyncIterable, ) from typing_extensions import ( Set, @@ -56,6 +58,13 @@ else: Base64FileInput = Union[IO[bytes], PathLike] FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. + + +# Used for sending raw binary data / streaming data in request bodies +# e.g. 
def get_google_colab_secret(secret_name: str = "TOGETHER_API_KEY") -> Union[str, None]:
    """
    Look up a secret in Google Colab's notebook secrets.

    The lookup is only attempted when ``google.colab`` is already present in
    ``sys.modules``; outside of Colab the function returns ``None``.

    Args:
        secret_name (str, optional): name of the Colab secret to read.
            Defaults to TOGETHER_API_KEY.

    Returns:
        str: the secret value if it exists and is a string; None if not running
        in Colab, if the secret is missing, if notebook access to it is
        disabled, or if the stored value is not a string.
    """
    # Guard clause: nothing to do unless the Colab runtime has been loaded.
    if "google.colab" not in sys.modules:
        return None

    try:
        from google.colab import userdata  # type: ignore
    except ImportError:
        # The secrets API is unavailable; treat it like a missing secret
        # instead of failing client construction.
        return None

    try:
        api_key = userdata.get(secret_name)  # type: ignore
    except userdata.NotebookAccessError:  # type: ignore
        log_info(
            f"The {secret_name} Colab secret was found, but notebook access is disabled. Please enable notebook "
            "access for the secret."
        )
        return None
    except userdata.SecretNotFoundError:  # type: ignore
        # warn and carry on
        log_info(f"Colab: No Google Colab secret named {secret_name} was found.")
        return None

    return api_key if isinstance(api_key, str) else None
def auto_track_command(command: str) -> Callable[[F], F]:
    """Decorator for click commands to automatically track CLI commands start/completion/failure.

    A ``CommandStarted`` event is sent before the wrapped function runs, then
    exactly one outcome event is sent afterwards:

    * ``CommandUserAborted`` if the function raises ``click.Abort`` or
      ``KeyboardInterrupt``;
    * ``CommandFailed`` (with ``error`` set to ``str(exc)``) if it raises any
      other ``Exception``;
    * ``CommandCompleted`` otherwise.

    Exceptions are always re-raised unchanged, so click (or an outer decorator)
    still sees them and the process exit status is not altered by tracking.

    Args:
        command: human-readable command name reported in every event.

    Returns:
        A decorator that wraps the command function with tracking.
    """

    def decorator(f: F) -> F:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            event_args = {"command": command, "arguments": kwargs}
            track_cli(CliTrackingEvents.CommandStarted, event_args)

            # Set once an abort/failure event has been sent, so the `finally`
            # clause does not additionally report the command as completed.
            outcome_reported = False
            try:
                return f(*args, **kwargs)
            except (click.Abort, KeyboardInterrupt):
                # Ctrl-C inside a command body surfaces as KeyboardInterrupt;
                # click only turns it into Abort further up, in its own main().
                outcome_reported = True
                track_cli(CliTrackingEvents.CommandUserAborted, event_args)
                # Re-raise so click still handles the abort.
                raise
            except Exception as e:
                outcome_reported = True
                track_cli(CliTrackingEvents.CommandFailed, {**event_args, "error": str(e)})
                # Bare raise keeps the original traceback intact.
                raise
            finally:
                if not outcome_reported:
                    track_cli(CliTrackingEvents.CommandCompleted, event_args)

        return wrapper  # type: ignore

    return decorator  # type: ignore
b/src/together/lib/cli/api/beta/clusters.py @@ -12,6 +12,7 @@ from together._response import APIResponse as APIResponse from together.types.beta import Cluster, ClusterCreateParams from together.lib.cli.api.utils import handle_api_errors +from together.lib.cli._track_cli import auto_track_command from together.types.beta.cluster_create_params import SharedVolume from together.lib.cli.api.beta.clusters_storage import storage @@ -47,6 +48,7 @@ def clusters(ctx: click.Context) -> None: help="Output in JSON format", ) @click.pass_context +@auto_track_command("clusters list") def list(ctx: click.Context, json: bool) -> None: """List clusters""" client: Together = ctx.obj @@ -108,6 +110,7 @@ def list(ctx: click.Context, json: bool) -> None: ) @click.pass_context @handle_api_errors("Clusters") +@auto_track_command("clusters create") def create( ctx: click.Context, name: str | None = None, @@ -239,6 +242,7 @@ def create( ) @click.pass_context @handle_api_errors("Clusters") +@auto_track_command("clusters retrieve") def retrieve(ctx: click.Context, cluster_id: str, json: bool) -> None: """Retrieve a cluster by ID""" client: Together = ctx.obj @@ -273,6 +277,7 @@ def retrieve(ctx: click.Context, cluster_id: str, json: bool) -> None: ) @click.pass_context @handle_api_errors("Clusters") +@auto_track_command("clusters update") def update( ctx: click.Context, cluster_id: str, @@ -308,6 +313,7 @@ def update( ) @click.pass_context @handle_api_errors("Clusters") +@auto_track_command("clusters delete") def delete(ctx: click.Context, cluster_id: str, json: bool) -> None: """Delete a cluster by ID""" client: Together = ctx.obj @@ -336,6 +342,7 @@ def delete(ctx: click.Context, cluster_id: str, json: bool) -> None: ) @click.pass_context @handle_api_errors("Clusters") +@auto_track_command("clusters list-regions") def list_regions(ctx: click.Context, json: bool) -> None: """List regions""" client: Together = ctx.obj diff --git a/src/together/lib/cli/api/beta/clusters_storage.py 
b/src/together/lib/cli/api/beta/clusters_storage.py index 0a3214fa..2c25b593 100644 --- a/src/together/lib/cli/api/beta/clusters_storage.py +++ b/src/together/lib/cli/api/beta/clusters_storage.py @@ -7,6 +7,7 @@ from together import Together from together.lib.cli.api.utils import handle_api_errors +from together.lib.cli._track_cli import auto_track_command from together.types.beta.clusters import ClusterStorage @@ -56,6 +57,7 @@ def storage(ctx: click.Context) -> None: ) @click.pass_context @handle_api_errors("Clusters Storage") +@auto_track_command("clusters storage create") def create(ctx: click.Context, region: str, size_tib: int, volume_name: str, json: bool) -> None: """Create a storage volume""" client: Together = ctx.obj @@ -85,6 +87,7 @@ def create(ctx: click.Context, region: str, size_tib: int, volume_name: str, jso ) @click.pass_context @handle_api_errors("Clusters Storage") +@auto_track_command("clusters storage retrieve") def retrieve(ctx: click.Context, volume_id: str, json: bool) -> None: """Retrieve a storage volume""" client: Together = ctx.obj @@ -112,6 +115,7 @@ def retrieve(ctx: click.Context, volume_id: str, json: bool) -> None: ) @click.pass_context @handle_api_errors("Clusters Storage") +@auto_track_command("clusters storage delete") def delete(ctx: click.Context, volume_id: str, json: bool) -> None: """Delete a storage volume""" client: Together = ctx.obj @@ -140,6 +144,7 @@ def delete(ctx: click.Context, volume_id: str, json: bool) -> None: ) @click.pass_context @handle_api_errors("Clusters Storage") +@auto_track_command("clusters storage list") def list(ctx: click.Context, json: bool) -> None: """List storage volumes""" client: Together = ctx.obj diff --git a/src/together/lib/cli/api/endpoints.py b/src/together/lib/cli/api/endpoints.py index 8383f629..205a5967 100644 --- a/src/together/lib/cli/api/endpoints.py +++ b/src/together/lib/cli/api/endpoints.py @@ -10,6 +10,7 @@ from together import Together, omit from together.types import 
DedicatedEndpoint from together._exceptions import APIError +from together.lib.cli._track_cli import auto_track_command from together.lib.utils.serializer import datetime_serializer from together.types.endpoint_list_response import Data as DedicatedEndpointListItem @@ -134,6 +135,7 @@ def endpoints(ctx: click.Context) -> None: ) @click.pass_obj @handle_api_errors +@auto_track_command("endpoints create") def create( client: Together, model: str, @@ -221,6 +223,7 @@ def create( @click.option("--json", is_flag=True, help="Print output in JSON format") @click.pass_obj @handle_api_errors +@auto_track_command("endpoints get") def get(client: Together, endpoint_id: str, json: bool) -> None: """Get a dedicated inference endpoint.""" endpoint = client.endpoints.retrieve(endpoint_id) @@ -242,6 +245,7 @@ def get(client: Together, endpoint_id: str, json: bool) -> None: ) @click.pass_obj @handle_api_errors +@auto_track_command("endpoints hardware") def hardware(client: Together, model: str | None, json: bool, available: bool) -> None: """List all available hardware options, optionally filtered by model.""" fetch_and_print_hardware_options(client, model, json, available) @@ -274,6 +278,7 @@ def fetch_and_print_hardware_options(client: Together, model: str | None, print_ @click.option("--wait/--no-wait", default=True, help="Wait for the endpoint to stop") @click.pass_obj @handle_api_errors +@auto_track_command("endpoints stop") def stop(client: Together, endpoint_id: str, wait: bool) -> None: """Stop a dedicated inference endpoint.""" client.endpoints.update(endpoint_id, state="STOPPED") @@ -295,6 +300,7 @@ def stop(client: Together, endpoint_id: str, wait: bool) -> None: @click.option("--wait/--no-wait", default=True, help="Wait for the endpoint to start") @click.pass_obj @handle_api_errors +@auto_track_command("endpoints start") def start(client: Together, endpoint_id: str, wait: bool) -> None: """Start a dedicated inference endpoint.""" client.endpoints.update(endpoint_id, 
state="STARTED") @@ -315,6 +321,7 @@ def start(client: Together, endpoint_id: str, wait: bool) -> None: @click.argument("endpoint-id", required=True) @click.pass_obj @handle_api_errors +@auto_track_command("endpoints delete") def delete(client: Together, endpoint_id: str) -> None: """Delete a dedicated inference endpoint.""" client.endpoints.delete(endpoint_id) @@ -342,6 +349,7 @@ def delete(client: Together, endpoint_id: str) -> None: ) @click.pass_obj @handle_api_errors +@auto_track_command("endpoints list") def list( client: Together, json: bool, @@ -400,6 +408,7 @@ def list( ) @click.pass_obj @handle_api_errors +@auto_track_command("endpoints update") def update( client: Together, endpoint_id: str, @@ -449,6 +458,7 @@ def update( @click.option("--json", is_flag=True, help="Print output in JSON format") @click.pass_obj @handle_api_errors +@auto_track_command("endpoints availability-zones") def availability_zones(client: Together, json: bool) -> None: """List all availability zones.""" avzones = client.endpoints.list_avzones() diff --git a/src/together/lib/cli/api/evals.py b/src/together/lib/cli/api/evals.py index 8eb4dba6..807dc0a5 100644 --- a/src/together/lib/cli/api/evals.py +++ b/src/together/lib/cli/api/evals.py @@ -8,6 +8,7 @@ from together import APIError, Together, TogetherError from together._types import omit +from together.lib.cli._track_cli import auto_track_command from together.lib.utils.serializer import datetime_serializer from together.types.eval_create_params import ( ParametersEvaluationScoreParameters, @@ -274,6 +275,7 @@ def evals(ctx: click.Context) -> None: help="Input template for model B.", ) @handle_api_errors +@auto_track_command("evals create") def create( ctx: click.Context, type: Literal["classify", "score", "compare"], @@ -489,6 +491,7 @@ def create( type=int, help="Limit number of results (max 100).", ) +@auto_track_command("evals list") def list( ctx: click.Context, status: Union[Literal["pending", "queued", "running", 
"completed", "error", "user_error"], None], @@ -530,6 +533,7 @@ def list( @evals.command() @click.pass_context @click.argument("evaluation_id", type=str, required=True) +@auto_track_command("evals retrieve") def retrieve(ctx: click.Context, evaluation_id: str) -> None: """Get details of a specific evaluation job""" @@ -543,6 +547,7 @@ def retrieve(ctx: click.Context, evaluation_id: str) -> None: @evals.command() @click.pass_context @click.argument("evaluation_id", type=str, required=True) +@auto_track_command("evals status") def status(ctx: click.Context, evaluation_id: str) -> None: """Get the status and results of a specific evaluation job""" diff --git a/src/together/lib/cli/api/files.py b/src/together/lib/cli/api/files.py index 9d716c71..d60ea691 100644 --- a/src/together/lib/cli/api/files.py +++ b/src/together/lib/cli/api/files.py @@ -8,6 +8,7 @@ from together import Together from together.types import FilePurpose +from together.lib.cli._track_cli import auto_track_command # from together.utils import check_file, convert_bytes, convert_unix_timestamp from ...utils import check_file, convert_bytes, convert_unix_timestamp @@ -38,6 +39,7 @@ def files(ctx: click.Context) -> None: default=True, help="Whether to check the file before uploading.", ) +@auto_track_command("files upload") def upload(ctx: click.Context, file: pathlib.Path, purpose: FilePurpose, check: bool) -> None: """Upload file""" @@ -50,6 +52,7 @@ def upload(ctx: click.Context, file: pathlib.Path, purpose: FilePurpose, check: @files.command() @click.pass_context +@auto_track_command("files list") def list(ctx: click.Context) -> None: """List files""" client: Together = ctx.obj @@ -75,6 +78,7 @@ def list(ctx: click.Context) -> None: @files.command() @click.pass_context @click.argument("id", type=str, required=True) +@auto_track_command("files retrieve") def retrieve(ctx: click.Context, id: str) -> None: """Upload file""" @@ -89,6 +93,7 @@ def retrieve(ctx: click.Context, id: str) -> None: 
@click.pass_context @click.argument("id", type=str, required=True) @click.option("--output", type=str, default=None, help="Output filename") +@auto_track_command("files retrieve-content") def retrieve_content(ctx: click.Context, id: str, output: str) -> None: """Retrieve file content and output to file""" @@ -108,6 +113,7 @@ def retrieve_content(ctx: click.Context, id: str, output: str) -> None: @files.command() @click.pass_context @click.argument("id", type=str, required=True) +@auto_track_command("files delete") def delete(ctx: click.Context, id: str) -> None: """Delete remote file""" @@ -125,6 +131,7 @@ def delete(ctx: click.Context, id: str) -> None: type=click.Path(exists=True, file_okay=True, resolve_path=True, readable=True, dir_okay=False), required=True, ) +@auto_track_command("files check") def check(_ctx: click.Context, file: pathlib.Path) -> None: """Check file for issues""" diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py index 549fe569..3efc8087 100644 --- a/src/together/lib/cli/api/fine_tuning.py +++ b/src/together/lib/cli/api/fine_tuning.py @@ -19,6 +19,7 @@ from together.lib.utils import log_warn from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO, generate_progress_bar +from together.lib.cli._track_cli import auto_track_command from together.lib.resources.files import DownloadManager from together.lib.utils.serializer import datetime_serializer from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType @@ -216,6 +217,7 @@ def fine_tuning(ctx: click.Context) -> None: default=None, help="HF repo to upload the fine-tuned model to", ) +@auto_track_command("fine-tuning create") def create( ctx: click.Context, training_file: str, @@ -415,6 +417,7 @@ def create( @fine_tuning.command() @click.pass_context +@auto_track_command("fine-tuning list") def list(ctx: 
click.Context) -> None: """List fine-tuning jobs""" client: Together = ctx.obj @@ -449,6 +452,7 @@ def list(ctx: click.Context) -> None: @fine_tuning.command() @click.pass_context @click.argument("fine_tune_id", type=str, required=True) +@auto_track_command("fine-tuning retrieve") def retrieve(ctx: click.Context, fine_tune_id: str) -> None: """Retrieve fine-tuning job details""" client: Together = ctx.obj @@ -468,6 +472,7 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None: @click.pass_context @click.argument("fine_tune_id", type=str, required=True) @click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before cancelling job") +@auto_track_command("fine-tuning cancel") def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None: """Cancel fine-tuning job""" client: Together = ctx.obj @@ -487,6 +492,7 @@ def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None: @fine_tuning.command() @click.pass_context @click.argument("fine_tune_id", type=str, required=True) +@auto_track_command("fine-tuning list-events") def list_events(ctx: click.Context, fine_tune_id: str) -> None: """List fine-tuning events""" client: Together = ctx.obj @@ -513,6 +519,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None: @fine_tuning.command() @click.pass_context @click.argument("fine_tune_id", type=str, required=True) +@auto_track_command("fine-tuning list-checkpoints") def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None: """List available checkpoints for a fine-tuning job""" client: Together = ctx.obj @@ -569,6 +576,7 @@ def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None: default="merged", help="Specifies checkpoint type. 
'merged' and 'adapter' options work only for LoRA jobs.", ) +@auto_track_command("fine-tuning download") def download( ctx: click.Context, fine_tune_id: str, @@ -628,6 +636,7 @@ def download( @click.argument("fine_tune_id", type=str, required=True) @click.option("--force", is_flag=True, help="Force deletion without confirmation") @click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before deleting job") +@auto_track_command("fine-tuning delete") def delete(ctx: click.Context, fine_tune_id: str, force: bool = False, quiet: bool = False) -> None: """Delete fine-tuning job""" client: Together = ctx.obj diff --git a/src/together/lib/cli/api/models.py b/src/together/lib/cli/api/models.py index cb566695..737a4d31 100644 --- a/src/together/lib/cli/api/models.py +++ b/src/together/lib/cli/api/models.py @@ -7,6 +7,7 @@ from together import Together, omit from together._models import BaseModel from together._response import APIResponse as APIResponse +from together.lib.cli._track_cli import auto_track_command from together.types.model_upload_response import ModelUploadResponse @@ -29,12 +30,16 @@ def models(ctx: click.Context) -> None: help="Output in JSON format", ) @click.pass_context +@auto_track_command("models list") def list(ctx: click.Context, type: Optional[str], json: bool) -> None: """List models""" client: Together = ctx.obj models_list = client.models.list(dedicated=type == "dedicated" if type else omit) + if json: + click.echo(json_lib.dumps(models_list, indent=2)) + display_list: List[Dict[str, Any]] = [] model: BaseModel for model in models_list: @@ -51,10 +56,7 @@ def list(ctx: click.Context, type: Optional[str], json: bool) -> None: } ) - if json: - click.echo(json_lib.dumps(display_list, indent=2)) - else: - click.echo(tabulate(display_list, headers="keys", tablefmt="plain")) + click.echo(tabulate(display_list, headers="keys", tablefmt="plain")) @models.command() @@ -96,6 +98,7 @@ def list(ctx: click.Context, type: Optional[str], 
json: bool) -> None: help="Output in JSON format", ) @click.pass_context +@auto_track_command("models upload") def upload( ctx: click.Context, model_name: str, diff --git a/src/together/lib/cli/cli.py b/src/together/lib/cli/cli.py index c3f01d49..f549b042 100644 --- a/src/together/lib/cli/cli.py +++ b/src/together/lib/cli/cli.py @@ -4,12 +4,14 @@ from typing import Any import click +import httpx import together from together._version import __version__ from together._constants import DEFAULT_TIMEOUT from together.lib.cli.api.evals import evals from together.lib.cli.api.files import files +from together.lib.cli._track_cli import CliTrackingEvents, track_cli from together.lib.cli.api.models import models from together.lib.cli.api.beta.beta import beta from together.lib.cli.api.endpoints import endpoints @@ -57,10 +59,22 @@ def main( ) -> None: """This is a sample CLI tool.""" os.environ.setdefault("TOGETHER_LOG", "debug" if debug else "info") - ctx.obj = together.Together( + + client = together.Together( api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries if max_retries is not None else 0 ) + # Wrap the client's httpx requests to track the parameters sent on api requests + def track_request(request: httpx.Request) -> None: + track_cli( + CliTrackingEvents.ApiRequest, + {"url": str(request.url), "method": request.method, "body": request.content.decode("utf-8")}, + ) + + client._client.event_hooks["request"].append(track_request) + + ctx.obj = client + main.add_command(files) main.add_command(fine_tuning) diff --git a/src/together/resources/beta/clusters/clusters.py b/src/together/resources/beta/clusters/clusters.py index d18f215c..c3e89a8c 100644 --- a/src/together/resources/beta/clusters/clusters.py +++ b/src/together/resources/beta/clusters/clusters.py @@ -80,8 +80,13 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterCreateResponse: - """ - Create GPU Cluster + """Create 
an Instant Cluster on Together's high-performance GPU clusters. + + With + features like on-demand scaling, long-lived resizable high-bandwidth shared + DC-local storage, Kubernetes and Slurm cluster flavors, a REST API, and + Terraform support, you can run workloads flexibly without complex infrastructure + management. Args: cluster_name: Name of the GPU cluster. @@ -141,7 +146,7 @@ def retrieve( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> Cluster: """ - Get GPU cluster by cluster ID + Retrieve information about a specific GPU cluster. Args: extra_headers: Send extra headers @@ -176,7 +181,7 @@ def update( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterUpdateResponse: """ - Update a GPU Cluster. + Update the configuration of an existing GPU cluster. Args: extra_headers: Send extra headers @@ -235,7 +240,7 @@ def delete( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterDeleteResponse: """ - Delete GPU cluster by cluster ID + Delete a GPU cluster by cluster ID. Args: extra_headers: Send extra headers @@ -320,8 +325,13 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterCreateResponse: - """ - Create GPU Cluster + """Create an Instant Cluster on Together's high-performance GPU clusters. + + With + features like on-demand scaling, long-lived resizable high-bandwidth shared + DC-local storage, Kubernetes and Slurm cluster flavors, a REST API, and + Terraform support, you can run workloads flexibly without complex infrastructure + management. Args: cluster_name: Name of the GPU cluster. @@ -381,7 +391,7 @@ async def retrieve( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> Cluster: """ - Get GPU cluster by cluster ID + Retrieve information about a specific GPU cluster. 
Args: extra_headers: Send extra headers @@ -416,7 +426,7 @@ async def update( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterUpdateResponse: """ - Update a GPU Cluster. + Update the configuration of an existing GPU cluster. Args: extra_headers: Send extra headers @@ -475,7 +485,7 @@ async def delete( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterDeleteResponse: """ - Delete GPU cluster by cluster ID + Delete a GPU cluster by cluster ID. Args: extra_headers: Send extra headers diff --git a/src/together/resources/beta/clusters/storage.py b/src/together/resources/beta/clusters/storage.py index c7915737..508f9471 100644 --- a/src/together/resources/beta/clusters/storage.py +++ b/src/together/resources/beta/clusters/storage.py @@ -57,12 +57,15 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageCreateResponse: - """Create a shared volume. + """ + Instant Clusters supports long-lived, resizable in-DC shared storage with user + data persistence. You can dynamically create and attach volumes to your cluster + at cluster creation time, and resize as your data grows. All shared storage is + backed by multi-NIC bare metal paths, ensuring high-throughput and low-latency + performance for shared storage. Args: - region: Region name. - - Usable regions can be found from `client.clusters.list_regions()` + region: Region name. Usable regions can be found from `client.clusters.list_regions()` size_tib: Volume size in whole tebibytes (TiB). @@ -102,7 +105,7 @@ def retrieve( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterStorage: """ - Get shared volume by volume Id. + Retrieve information about a specific shared volume. Args: extra_headers: Send extra headers @@ -136,7 +139,7 @@ def update( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterStorage: """ - Update a shared volume. 
+ Update the configuration of an existing shared volume. Args: extra_headers: Send extra headers @@ -192,8 +195,10 @@ def delete( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageDeleteResponse: - """ - Delete shared volume by volume id. + """Delete a shared volume. + + Note that if this volume is attached to a cluster, + deleting will fail. Args: extra_headers: Send extra headers @@ -248,12 +253,15 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageCreateResponse: - """Create a shared volume. + """ + Instant Clusters supports long-lived, resizable in-DC shared storage with user + data persistence. You can dynamically create and attach volumes to your cluster + at cluster creation time, and resize as your data grows. All shared storage is + backed by multi-NIC bare metal paths, ensuring high-throughput and low-latency + performance for shared storage. Args: - region: Region name. - - Usable regions can be found from `client.clusters.list_regions()` + region: Region name. Usable regions can be found from `client.clusters.list_regions()` size_tib: Volume size in whole tebibytes (TiB). @@ -293,7 +301,7 @@ async def retrieve( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterStorage: """ - Get shared volume by volume Id. + Retrieve information about a specific shared volume. Args: extra_headers: Send extra headers @@ -327,7 +335,7 @@ async def update( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterStorage: """ - Update a shared volume. + Update the configuration of an existing shared volume. Args: extra_headers: Send extra headers @@ -383,8 +391,10 @@ async def delete( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageDeleteResponse: - """ - Delete shared volume by volume id. + """Delete a shared volume. 
+ + Note that if this volume is attached to a cluster, + deleting will fail. Args: extra_headers: Send extra headers diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index 0aa7ff03..645b89dd 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -1,6 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from __future__ import annotations # noqa +from __future__ import annotations # noqa from .batch_job import BatchJob as BatchJob from .embedding import Embedding as Embedding @@ -70,12 +70,13 @@ ) # Manually added to minimize breaking changes from V1 -from .chat.chat_completion import ChatCompletion +from .chat.chat_completion import ChatCompletion from .chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk from .chat.chat_completion_usage import ChatCompletionUsage + UsageData = ChatCompletionUsage ChatCompletionResponse = ChatCompletion CompletionResponse = Completion ListEndpoint = EndpointListResponse ImageRequest = ImageGenerateParams -ImageResponse = ImageFile \ No newline at end of file +ImageResponse = ImageFile diff --git a/src/together/types/chat_completions.py b/src/together/types/chat_completions.py index 956a6fda..dfa7f175 100644 --- a/src/together/types/chat_completions.py +++ b/src/together/types/chat_completions.py @@ -4,4 +4,4 @@ from .chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk ChatCompletionResponse = ChatCompletion -ToolCalls = ToolChoice \ No newline at end of file +ToolCalls = ToolChoice diff --git a/src/together/types/endpoints.py b/src/together/types/endpoints.py index b2e49e6d..808df6ca 100644 --- a/src/together/types/endpoints.py +++ b/src/together/types/endpoints.py @@ -1,4 +1,4 @@ # Manually added to minimize breaking changes from V1 from together.types import DedicatedEndpoint as DedicatedEndpoint -ListEndpoint = DedicatedEndpoint \ No newline at end of file +ListEndpoint = 
DedicatedEndpoint diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 8d371cac..a3322f1d 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -32,7 +32,7 @@ def create_finetune_response( return FinetuneResponse( id=job_id, progress=progress, - updated_at=started_at, # to calm down mypy + updated_at=started_at, # to calm down mypy started_at=started_at, status=status, created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), diff --git a/tests/test_client.py b/tests/test_client.py index 579f2f6b..011fef4b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,10 +8,11 @@ import json import asyncio import inspect +import dataclasses import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast from unittest import mock -from typing_extensions import Literal +from typing_extensions import Literal, AsyncIterator, override import httpx import pytest @@ -37,6 +38,7 @@ from .utils import update_env +T = TypeVar("T") base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "My API Key" @@ -51,6 +53,57 @@ def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float: return 0.1 +def mirror_request_content(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=request.content) + + +# note: we can't use the httpx.MockTransport class as it consumes the request +# body itself, which means we can't test that the body is read lazily +class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport): + def __init__( + self, + handler: Callable[[httpx.Request], httpx.Response] + | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]], + ) -> None: + self.handler = handler + + @override + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert not inspect.iscoroutinefunction(self.handler), "handler must not be a coroutine function" + assert 
inspect.isfunction(self.handler), "handler must be a function" + return self.handler(request) + + @override + async def handle_async_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert inspect.iscoroutinefunction(self.handler), "handler must be a coroutine function" + return await self.handler(request) + + +@dataclasses.dataclass +class Counter: + value: int = 0 + + +def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + +async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + def _get_open_connections(client: Together | AsyncTogether) -> int: transport = client._client._transport assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) @@ -503,6 +556,70 @@ def test_multipart_repeating_array(self, client: Together) -> None: b"", ] + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload(self, respx_mock: MockRouter, client: Together) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + def test_binary_content_upload_with_iterator(self) -> None: + file_content = b"Hello, this is a test file." 
+ counter = Counter() + iterator = _make_sync_iterator([file_content], counter=counter) + + def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=request.read()) + + with Together( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(transport=MockTransport(handler=mock_handler)), + ) as client: + response = client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: Together) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." 
+ ): + response = client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + @pytest.mark.respx(base_url=base_url) def test_basic_union_response(self, respx_mock: MockRouter, client: Together) -> None: class Model1(BaseModel): @@ -1375,6 +1492,72 @@ def test_multipart_repeating_array(self, async_client: AsyncTogether) -> None: b"", ] + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncTogether) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = await async_client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + async def test_binary_content_upload_with_asynciterator(self) -> None: + file_content = b"Hello, this is a test file." 
+ counter = Counter() + iterator = _make_async_iterator([file_content], counter=counter) + + async def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=await request.aread()) + + async with AsyncTogether( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)), + ) as client: + response = await client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload_with_body_is_deprecated( + self, respx_mock: MockRouter, async_client: AsyncTogether + ) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." 
+ ): + response = await async_client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + @pytest.mark.respx(base_url=base_url) async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncTogether) -> None: class Model1(BaseModel): diff --git a/uv.lock b/uv.lock index ffc03b56..6c3569f8 100644 --- a/uv.lock +++ b/uv.lock @@ -1213,6 +1213,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, ] +[[package]] +name = "py-machineid" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "winregistry", marker = "sys_platform == 'win32' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/b0/c7fa6de7298a8f4e544929b97fa028304c0e11a4bc9500eff8689821bdbb/py_machineid-1.0.0.tar.gz", hash = "sha256:8a902a00fae8c6d6433f463697c21dc4ce98c6e55a2e0535c0273319acb0047a", size = 4629, upload-time = "2025-12-02T16:12:54.286Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/76/1ed8375cb1212824c57eb706e1f09f3f2ca4ed12b8d56b28a160e2d53505/py_machineid-1.0.0-py3-none-any.whl", hash = "sha256:910df0d5f2663bcf6739d835c4949f4e9cc6bb090a58b3dd766e12e5f768e3b9", size = 4926, upload-time = "2025-12-02T16:12:20.584Z" }, +] + [[package]] name = "pyarrow" version = "21.0.0" @@ -1963,7 +1975,7 @@ wheels = [ [[package]] name = "together" -version = "2.0.0a14" +version = "2.0.0a15" source = { editable = "." 
} dependencies = [ { name = "anyio" }, @@ -1975,6 +1987,7 @@ dependencies = [ { name = "httpx" }, { name = "pillow", version = "11.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" }, { name = "pillow", version = "12.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" }, + { name = "py-machineid" }, { name = "pydantic", version = "1.10.24", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-8-together-pydantic-v1'" }, { name = "pydantic", version = "2.12.5", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-8-together-pydantic-v2' or extra != 'group-8-together-pydantic-v1'" }, { name = "rich" }, @@ -2033,6 +2046,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.23.0,<1" }, { name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.9" }, { name = "pillow", specifier = ">=10.4.0" }, + { name = "py-machineid", specifier = ">=1.0.0" }, { name = "pyarrow", marker = "extra == 'pyarrow'", specifier = ">=16.1.0" }, { name = "pyarrow-stubs", marker = "extra == 'pyarrow'", specifier = ">=10.0.1.7" }, { name = "pydantic", specifier = ">=1.9.0,<3" }, @@ -2192,6 +2206,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "winregistry" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/86/94/ddc339d2562267af7d25d5067874f7df8c6c19ab9dd976fa830982b1c398/winregistry-2.1.2.tar.gz", hash = 
"sha256:50260e1aaba4116f707f86a4e287ffcb1eeae7dc0a0883c6d1776017e693fc69", size = 9538, upload-time = "2025-10-09T09:25:07.391Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/dd/5a18d9fbf9a3d69b40e395d80779dfaeda77b98c946df36bf7df41ddcaa5/winregistry-2.1.2-py3-none-any.whl", hash = "sha256:e142548f56fc1fc6b83ddf88baca2e9e18cd6a266d9e00f111e54977dee768cf", size = 8507, upload-time = "2025-10-09T09:25:05.82Z" }, +] + [[package]] name = "yarl" version = "1.22.0"