diff --git a/.stats.yml b/.stats.yml
index a3e1592f..edc1c05a 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 55
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-51627d25c5c4ea3cf03c92a335acf66cf8cad652079915109fe9711a57f7e003.yml
-openapi_spec_hash: 97f97a89965aa05900566ca2824a4de1
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-28d95f054f8ffe846f42a014a1c86d0385604e9d05850c17819ef986a068ac88.yml
+openapi_spec_hash: 83d5ac256007e9ccd40abe11a5983168
config_hash: 6acd26f13abe2b4550fb4bbb06d31523
diff --git a/examples/tokenize_data.py b/examples/tokenize_data.py
new file mode 100644
index 00000000..1b9d0035
--- /dev/null
+++ b/examples/tokenize_data.py
@@ -0,0 +1,222 @@
+import logging
+import argparse
+from typing import Dict, List
+from functools import partial
+from multiprocessing import cpu_count
+
+from datasets import Dataset, load_dataset # type: ignore
+from transformers import ( # type: ignore
+ AutoTokenizer,
+ BatchEncoding,
+ PreTrainedTokenizerBase,
+)
+
+# see default of ignore_index
+# for https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
+LOSS_IGNORE_INDEX = -100
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def tokenize_variable_length(
+ data: Dict[str, str],
+ tokenizer: PreTrainedTokenizerBase,
+ add_special_tokens: bool = True,
+) -> BatchEncoding:
+ tokenized = tokenizer(data["text"], add_special_tokens=add_special_tokens, truncation=False)
+ return tokenized
+
+
+def tokenize_constant_length(
+    data: Dict[str, str],
+    tokenizer: PreTrainedTokenizerBase,
+    max_length: int = 2048,
+    add_special_tokens: bool = True,
+    add_labels: bool = True,
+) -> BatchEncoding:
+    """Tokenize one example to exactly `max_length` tokens (truncate + pad).
+
+    NOTE(review): the label comprehension below iterates `tokenized["input_ids"]`
+    as a flat list of token ids, so this assumes it is mapped with
+    `batched=False` (single example per call) — confirm against the caller.
+    """
+    # tokenized contains `input_ids` and `attention_mask`
+    tokenized: BatchEncoding = tokenizer(
+        data["text"],
+        max_length=max_length,
+        truncation=True,
+        padding="max_length",
+        add_special_tokens=add_special_tokens,
+    )
+    # add labels to mask out any padding tokens
+    if add_labels:
+        # Padding positions get LOSS_IGNORE_INDEX so CrossEntropyLoss skips them.
+        tokenized["labels"] = [
+            LOSS_IGNORE_INDEX if token_id == tokenizer.pad_token_id else token_id for token_id in tokenized["input_ids"]
+        ]
+
+    return tokenized
+
+
+def pack_sequences(
+    batch: BatchEncoding,
+    max_seq_len: int,
+    pad_token_id: int,
+    eos_token_id: int,
+    add_labels: bool,
+    cutoff_size: int = 0,
+) -> Dict[str, List[List[int]]]:
+    """
+    Concatenate tokenized examples (EOS-separated) and re-chunk them into
+    fixed-length sequences of `max_seq_len` tokens.
+
+    cutoff_size = max_seq_len means that we will drop any non-full sequences
+    (full packing without padding)
+
+    NOTE(review): the trailing buffer is dropped whenever its length is
+    <= cutoff_size, and the buffer does not carry over between `map` batches,
+    so up to one partial chunk per batch is discarded — confirm this is intended.
+    Example:
+    Sequence 1:
+    ['', '▁usually', '▁,', '▁he', '▁would', '▁be', '▁t', 'earing']
+    Sequence 2:
+    ['▁around', '▁the', '▁living', '▁room', '▁,', '▁playing', '▁with', '▁his']
+    Sequence 3:
+    ['▁toys', '▁.', '', '', '▁but', '▁just', '▁one', '▁look']
+    """
+    packed_sequences = []
+    buffer = []
+
+    for input_ids in batch["input_ids"]:
+        # Add the current sequence to the buffer
+        buffer.extend(input_ids)
+        buffer.append(eos_token_id)  # Add EOS at the end of each sequence
+
+        # Check if buffer needs to be split into chunks
+        while len(buffer) > max_seq_len:
+            # Take a full chunk from the buffer and append it to packed_sequences
+            packed_sequences.append(buffer[:max_seq_len])
+            # Remove the processed chunk from the buffer
+            buffer = buffer[max_seq_len:]
+
+    # Add the last buffer if it's exactly chunk_size
+    if len(buffer) == max_seq_len:
+        packed_sequences.append(buffer)
+    elif len(buffer) > cutoff_size:
+        # if the buffer is larger than the cutoff size, pad it to the chunk_size
+        # if not, we do not include in the packed_sequences
+        buffer.extend([pad_token_id] * (max_seq_len - len(buffer)))
+        packed_sequences.append(buffer)
+
+    output = {"input_ids": packed_sequences}
+    if add_labels:
+        # Padding positions are masked out of the loss with LOSS_IGNORE_INDEX.
+        output["labels"] = [
+            [LOSS_IGNORE_INDEX if token_id == pad_token_id else token_id for token_id in example]
+            for example in output["input_ids"]
+        ]
+
+    # mask attention for padding tokens, a better version would also mask cross-sequence dependencies
+    output["attention_mask"] = [
+        [0 if token_id == pad_token_id else 1 for token_id in example] for example in output["input_ids"]
+    ]
+    return output
+
+
+def process_fast_packing(
+    dataset: Dataset,
+    tokenizer: PreTrainedTokenizerBase,
+    max_sequence_length: int,
+    add_labels: bool,
+    add_special_tokens: bool,
+) -> Dataset:
+    """Tokenize a dataset and pack the token streams into fixed-length rows.
+
+    First tokenizes every example at its natural length, then re-chunks the
+    EOS-joined token stream into `max_sequence_length` rows via `pack_sequences`.
+    """
+    # Stage 1: variable-length tokenization, parallelized across all cores.
+    tokenized_dataset = dataset.map(
+        lambda examples: tokenize_variable_length(examples, tokenizer, add_special_tokens=add_special_tokens),
+        batched=True,
+        num_proc=cpu_count(),
+        load_from_cache_file=True,
+        remove_columns=dataset.column_names,
+    )
+    logger.info(f"tokenized dataset: {tokenized_dataset}")
+
+    # Stage 2: packing. cutoff_size == max_sequence_length means partial
+    # trailing chunks are dropped rather than padded (see pack_sequences).
+    packed_dataset = tokenized_dataset.map(
+        lambda batch: pack_sequences(
+            batch,
+            max_sequence_length,
+            tokenizer.pad_token_id,
+            tokenizer.eos_token_id,
+            add_labels=add_labels,
+            cutoff_size=max_sequence_length,
+        ),
+        batched=True,
+        # Multiprocessing overhead is only worth it for larger datasets.
+        num_proc=cpu_count() if len(tokenized_dataset) > 10000 else 1,
+        # Drop the stage-1 attention_mask; pack_sequences emits a fresh one.
+        remove_columns=["attention_mask"],
+    )
+    logger.info(f"Packed dataset: {packed_dataset}")
+    return packed_dataset
+
+
+def process_data(args: argparse.Namespace) -> None:
+ if not args.out_filename.endswith(".parquet"):
+ raise ValueError("`--out_filename` should have the `.parquet` extension")
+
+ dataset = load_dataset(args.dataset, split="train")
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ dataset.to_json("dataset.jsonl", orient="records", lines=True)
+
+ if not args.packing:
+ tokenized_data = dataset.map(
+ partial(
+ tokenize_constant_length,
+ tokenizer=tokenizer,
+ max_length=args.max_seq_length,
+ add_special_tokens=True,
+ add_labels=args.add_labels,
+ ),
+ batched=False,
+ num_proc=cpu_count(),
+ remove_columns=dataset.column_names,
+ )
+ else:
+ tokenized_data = process_fast_packing(
+ dataset,
+ tokenizer,
+ max_sequence_length=args.max_seq_length,
+ add_labels=args.add_labels,
+ add_special_tokens=True,
+ )
+
+ assert "input_ids" in tokenized_data.column_names and "attention_mask" in tokenized_data.column_names
+
+ if args.add_labels:
+ assert "labels" in tokenized_data.column_names
+
+ logger.info("Tokenized data:")
+ print(tokenized_data)
+
+ logger.info(f"Saving data to {args.out_filename}")
+ print(len(tokenized_data[0]["input_ids"]))
+ tokenized_data.to_parquet(args.out_filename)
+
+
+if __name__ == "__main__":
+    # CLI entry point: parse arguments and run the tokenization pipeline.
+    parser = argparse.ArgumentParser(description="Pretokenize examples for finetuning via Together")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="clam004/antihallucination_dataset",
+        help="Dataset name on the Hugging Face Hub",
+    )
+    parser.add_argument("--max-seq-length", type=int, default=8192, help="Maximum sequence length")
+    parser.add_argument(
+        "--add-labels",
+        action="store_true",
+        help="Whether to add loss labels from padding tokens",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        type=str,
+        required=True,
+        help="Tokenizer name (for example, togethercomputer/Llama-3-8b-hf)",
+    )
+    parser.add_argument(
+        "--out-filename",
+        default="processed_dataset.parquet",
+        help="Name of the Parquet file to save (should have .parquet extension)",
+    )
+    parser.add_argument(
+        "--packing",
+        action="store_true",
+        help="Whether to pack shorter sequences up to `--max-seq-length`",
+    )
+    args = parser.parse_args()
+
+    process_data(args)
diff --git a/pyproject.toml b/pyproject.toml
index c90ff583..2cbfb713 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,14 +15,15 @@ dependencies = [
"anyio>=3.5.0, <5",
"distro>=1.7.0, <2",
"sniffio",
- "click>=8.1.7",
- "rich>=13.7.1",
- "tabulate>=0.9.0",
- "pillow>=10.4.0",
- "types-tabulate>=0.9.0.20240106",
- "tqdm>=4.67.1",
- "types-tqdm>=4.67.0.20250516",
- "filelock>=3.13.1",
+ "click>=8.1.7",
+ "rich>=13.7.1",
+ "tabulate>=0.9.0",
+ "pillow>=10.4.0",
+ "types-tabulate>=0.9.0.20240106",
+ "tqdm>=4.67.1",
+ "types-tqdm>=4.67.0.20250516",
+ "filelock>=3.13.1",
+ "py-machineid>=1.0.0",
]
requires-python = ">= 3.9"
@@ -150,6 +151,7 @@ exclude = [
".venv",
".nox",
".git",
+ "examples",
]
reportImplicitOverride = true
diff --git a/src/together/_base_client.py b/src/together/_base_client.py
index 3a631d34..9838c526 100644
--- a/src/together/_base_client.py
+++ b/src/together/_base_client.py
@@ -9,6 +9,7 @@
import inspect
import logging
import platform
+import warnings
import email.utils
from types import TracebackType
from random import random
@@ -51,9 +52,11 @@
ResponseT,
AnyMapping,
PostParser,
+ BinaryTypes,
RequestFiles,
HttpxSendArgs,
RequestOptions,
+ AsyncBinaryTypes,
HttpxRequestFiles,
ModelBuilderProtocol,
not_given,
@@ -477,8 +480,19 @@ def _build_request(
retries_taken: int = 0,
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
- log.debug("Request options: %s", model_dump(options, exclude_unset=True))
-
+ log.debug(
+ "Request options: %s",
+ model_dump(
+ options,
+ exclude_unset=True,
+ # Pydantic v1 can't dump every type we support in content, so we exclude it for now.
+ exclude={
+ "content",
+ }
+ if PYDANTIC_V1
+ else {},
+ ),
+ )
kwargs: dict[str, Any] = {}
json_data = options.json_data
@@ -532,7 +546,13 @@ def _build_request(
is_body_allowed = options.method.lower() != "get"
if is_body_allowed:
- if isinstance(json_data, bytes):
+ if options.content is not None and json_data is not None:
+ raise TypeError("Passing both `content` and `json_data` is not supported")
+ if options.content is not None and files is not None:
+ raise TypeError("Passing both `content` and `files` is not supported")
+ if options.content is not None:
+ kwargs["content"] = options.content
+ elif isinstance(json_data, bytes):
kwargs["content"] = json_data
else:
kwargs["json"] = json_data if is_given(json_data) else None
@@ -1194,6 +1214,7 @@ def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: Literal[False] = False,
@@ -1206,6 +1227,7 @@ def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: Literal[True],
@@ -1219,6 +1241,7 @@ def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: bool,
@@ -1231,13 +1254,25 @@ def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
options: RequestOptions = {},
files: RequestFiles | None = None,
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="post", url=path, json_data=body, files=to_httpx_files(files), **options
+ method="post", url=path, json_data=body, content=content, files=to_httpx_files(files), **options
)
return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))
@@ -1247,11 +1282,23 @@ def patch(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="patch", url=path, json_data=body, files=to_httpx_files(files), **options
+ method="patch", url=path, json_data=body, content=content, files=to_httpx_files(files), **options
)
return self.request(cast_to, opts)
@@ -1261,11 +1308,23 @@ def put(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="put", url=path, json_data=body, files=to_httpx_files(files), **options
+ method="put", url=path, json_data=body, content=content, files=to_httpx_files(files), **options
)
return self.request(cast_to, opts)
@@ -1275,9 +1334,19 @@ def delete(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: BinaryTypes | None = None,
options: RequestOptions = {},
) -> ResponseT:
- opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options)
return self.request(cast_to, opts)
def get_api_list(
@@ -1717,6 +1786,7 @@ async def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: Literal[False] = False,
@@ -1729,6 +1799,7 @@ async def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: Literal[True],
@@ -1742,6 +1813,7 @@ async def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: bool,
@@ -1754,13 +1826,25 @@ async def post(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
stream: bool = False,
stream_cls: type[_AsyncStreamT] | None = None,
) -> ResponseT | _AsyncStreamT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ method="post", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options
)
return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
@@ -1770,11 +1854,28 @@ async def patch(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ method="patch",
+ url=path,
+ json_data=body,
+ content=content,
+ files=await async_to_httpx_files(files),
+ **options,
)
return await self.request(cast_to, opts)
@@ -1784,11 +1885,23 @@ async def put(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
files: RequestFiles | None = None,
options: RequestOptions = {},
) -> ResponseT:
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if files is not None and content is not None:
+ raise TypeError("Passing both `files` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
opts = FinalRequestOptions.construct(
- method="put", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ method="put", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options
)
return await self.request(cast_to, opts)
@@ -1798,9 +1911,19 @@ async def delete(
*,
cast_to: Type[ResponseT],
body: Body | None = None,
+ content: AsyncBinaryTypes | None = None,
options: RequestOptions = {},
) -> ResponseT:
- opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options)
+ if body is not None and content is not None:
+ raise TypeError("Passing both `body` and `content` is not supported")
+ if isinstance(body, bytes):
+ warnings.warn(
+ "Passing raw bytes as `body` is deprecated and will be removed in a future version. "
+ "Please pass raw bytes via the `content` parameter instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options)
return await self.request(cast_to, opts)
def get_api_list(
diff --git a/src/together/_client.py b/src/together/_client.py
index 452d9387..aa85275c 100644
--- a/src/together/_client.py
+++ b/src/together/_client.py
@@ -3,11 +3,14 @@
from __future__ import annotations
import os
+import sys
from typing import TYPE_CHECKING, Any, Mapping
from typing_extensions import Self, override
import httpx
+from together.lib._google_colab import get_google_colab_secret
+
from . import _exceptions
from ._qs import Querystring
from ._types import (
@@ -113,6 +116,8 @@ def __init__(
"""
if api_key is None:
api_key = os.environ.get("TOGETHER_API_KEY")
+ if api_key is None and "google.colab" in sys.modules:
+ api_key = get_google_colab_secret("TOGETHER_API_KEY")
if api_key is None:
raise TogetherError(
"The api_key client option must be set either by passing api_key to the client or by setting the TOGETHER_API_KEY environment variable"
@@ -388,6 +393,8 @@ def __init__(
"""
if api_key is None:
api_key = os.environ.get("TOGETHER_API_KEY")
+ if api_key is None and "google.colab" in sys.modules:
+ api_key = get_google_colab_secret("TOGETHER_API_KEY")
if api_key is None:
raise TogetherError(
"The api_key client option must be set either by passing api_key to the client or by setting the TOGETHER_API_KEY environment variable"
diff --git a/src/together/_models.py b/src/together/_models.py
index ca9500b2..29070e05 100644
--- a/src/together/_models.py
+++ b/src/together/_models.py
@@ -3,7 +3,20 @@
import os
import inspect
import weakref
-from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Type,
+ Union,
+ Generic,
+ TypeVar,
+ Callable,
+ Iterable,
+ Optional,
+ AsyncIterable,
+ cast,
+)
from datetime import date, datetime
from typing_extensions import (
List,
@@ -787,6 +800,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
timeout: float | Timeout | None
files: HttpxRequestFiles | None
idempotency_key: str
+ content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None]
json_data: Body
extra_json: AnyMapping
follow_redirects: bool
@@ -805,6 +819,7 @@ class FinalRequestOptions(pydantic.BaseModel):
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
follow_redirects: Union[bool, None] = None
+ content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
diff --git a/src/together/_types.py b/src/together/_types.py
index a39b8518..cf3a156f 100644
--- a/src/together/_types.py
+++ b/src/together/_types.py
@@ -13,9 +13,11 @@
Mapping,
TypeVar,
Callable,
+ Iterable,
Iterator,
Optional,
Sequence,
+ AsyncIterable,
)
from typing_extensions import (
Set,
@@ -56,6 +58,13 @@
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
+
+
+# Used for sending raw binary data / streaming data in request bodies
+# e.g. for file uploads without multipart encoding
+BinaryTypes = Union[bytes, bytearray, IO[bytes], Iterable[bytes]]
+AsyncBinaryTypes = Union[bytes, bytearray, IO[bytes], AsyncIterable[bytes]]
+
FileTypes = Union[
# file (or bytes)
FileContent,
diff --git a/src/together/constants.py b/src/together/constants.py
index 78fafa3b..cd13f59c 100644
--- a/src/together/constants.py
+++ b/src/together/constants.py
@@ -31,4 +31,4 @@
MAX_CONNECTION_RETRIES = 2
MAX_RETRIES = DEFAULT_MAX_RETRIES
-BASE_URL = "https://api.together.xyz/v1"
\ No newline at end of file
+BASE_URL = "https://api.together.xyz/v1"
diff --git a/src/together/error.py b/src/together/error.py
index ebffa4e9..ad3e17ea 100644
--- a/src/together/error.py
+++ b/src/together/error.py
@@ -1,4 +1,3 @@
-
# Manually added to minimize breaking changes from V1
from ._exceptions import (
APIError as APIError,
@@ -13,4 +12,4 @@
Timeout = APITimeoutError
InvalidRequestError = BadRequestError
TogetherException = APIError
-ResponseError = APIResponseValidationError
\ No newline at end of file
+ResponseError = APIResponseValidationError
diff --git a/src/together/lib/_google_colab.py b/src/together/lib/_google_colab.py
new file mode 100644
index 00000000..55026277
--- /dev/null
+++ b/src/together/lib/_google_colab.py
@@ -0,0 +1,39 @@
+import sys
+from typing import Union
+
+from together.lib.utils._log import log_info
+
+
+def get_google_colab_secret(secret_name: str = "TOGETHER_API_KEY") -> Union[str, None]:
+ """
+ Checks to see if the user is running in Google Colab, and looks for the Together API Key secret.
+
+ Args:
+ secret_name (str, optional). Defaults to TOGETHER_API_KEY
+
+ Returns:
+ str: if the API key is found; None if an error occurred or the secret was not found.
+ """
+ # If running in Google Colab, check for Together in notebook secrets
+ if "google.colab" in sys.modules:
+ from google.colab import userdata # type: ignore
+
+ try:
+ api_key = userdata.get(secret_name) # type: ignore
+ if not isinstance(api_key, str):
+ return None
+ else:
+ return str(api_key)
+ except userdata.NotebookAccessError: # type: ignore
+ log_info(
+ "The TOGETHER_API_KEY Colab secret was found, but notebook access is disabled. Please enable notebook "
+ "access for the secret."
+ )
+ except userdata.SecretNotFoundError: # type: ignore
+ # warn and carry on
+ log_info("Colab: No Google Colab secret named TOGETHER_API_KEY was found.")
+
+ return None
+
+ else:
+ return None
diff --git a/src/together/lib/cli/_track_cli.py b/src/together/lib/cli/_track_cli.py
new file mode 100644
index 00000000..4886b418
--- /dev/null
+++ b/src/together/lib/cli/_track_cli.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import os
+import json
+import time
+import uuid
+import threading
+from enum import Enum
+from typing import Any, TypeVar, Callable
+from functools import wraps
+
+import click
+import httpx
+import machineid
+
+from together import __version__
+from together.lib.utils import log_debug
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+SESSION_ID = int(str(uuid.uuid4().int)[0:13])
+
+
+def is_tracking_enabled() -> bool:
+ # Users can opt-out of tracking with the environment variable.
+ if os.getenv("TOGETHER_TELEMETRY_DISABLED"):
+ log_debug("Analytics tracking disabled by environment variable")
+ return False
+
+ return True
+
+
+class CliTrackingEvents(Enum):
+ CommandStarted = "cli_command_started"
+ CommandCompleted = "cli_commmand_completed"
+ CommandFailed = "cli_command_failed"
+ CommandUserAborted = "cli_command_user_aborted"
+ ApiRequest = "cli_command_api_request"
+
+
+def track_cli(event_name: CliTrackingEvents, args: dict[str, Any]) -> None:
+ """Track a CLI event. Non-Blocking."""
+ if is_tracking_enabled() == False:
+ return
+
+ def send_event() -> None:
+ ANALYTICS_API_ENV_VAR = os.getenv("TOGETHER_TELEMETRY_API")
+ ANALYTICS_API = (
+ ANALYTICS_API_ENV_VAR if ANALYTICS_API_ENV_VAR else "https://api.together.ai/api/together-cli-events"
+ )
+
+ try:
+ client = httpx.Client()
+ client.post(
+ ANALYTICS_API,
+ headers={"content-type": "application/json", "user-agent": f"together-cli:{__version__}"},
+ content=json.dumps(
+ {
+ "event_name": event_name.value,
+ "event_properties": {
+ "is_ci": os.getenv("CI") is not None,
+ **args,
+ },
+ "event_options": {
+ "time": int(time.time() * 1000),
+ "session_id": str(SESSION_ID),
+ "device_id": machineid.id().lower(),
+ },
+ }
+ ),
+ )
+ except Exception as e:
+ log_debug("Error sending analytics event", error=e)
+ # No-op - this is not critical and we don't want to block the CLI
+ pass
+
+ threading.Thread(target=send_event).start()
+
+
+def auto_track_command(command: str) -> Callable[[F], F]:
+ """Decorator for click commands to automatically track CLI commands start/completion/failure."""
+
+ def decorator(f: F) -> F:
+ @wraps(f)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ track_cli(CliTrackingEvents.CommandStarted, {"command": command, "arguments": kwargs})
+ try:
+ return f(*args, **kwargs)
+ except click.Abort:
+ # Doesn't seem like this is working any more
+ track_cli(
+ CliTrackingEvents.CommandUserAborted,
+ {"command": command, "arguments": kwargs},
+ )
+ except Exception as e:
+ track_cli(CliTrackingEvents.CommandFailed, {"command": command, "arguments": kwargs, "error": str(e)})
+ raise e
+ finally:
+ track_cli(CliTrackingEvents.CommandCompleted, {"command": command, "arguments": kwargs})
+
+ return wrapper # type: ignore
+
+ return decorator # type: ignore
diff --git a/src/together/lib/cli/api/beta/clusters.py b/src/together/lib/cli/api/beta/clusters.py
index 038cc571..fbb42231 100644
--- a/src/together/lib/cli/api/beta/clusters.py
+++ b/src/together/lib/cli/api/beta/clusters.py
@@ -12,6 +12,7 @@
from together._response import APIResponse as APIResponse
from together.types.beta import Cluster, ClusterCreateParams
from together.lib.cli.api.utils import handle_api_errors
+from together.lib.cli._track_cli import auto_track_command
from together.types.beta.cluster_create_params import SharedVolume
from together.lib.cli.api.beta.clusters_storage import storage
@@ -47,6 +48,7 @@ def clusters(ctx: click.Context) -> None:
help="Output in JSON format",
)
@click.pass_context
+@auto_track_command("clusters list")
def list(ctx: click.Context, json: bool) -> None:
"""List clusters"""
client: Together = ctx.obj
@@ -108,6 +110,7 @@ def list(ctx: click.Context, json: bool) -> None:
)
@click.pass_context
@handle_api_errors("Clusters")
+@auto_track_command("clusters create")
def create(
ctx: click.Context,
name: str | None = None,
@@ -239,6 +242,7 @@ def create(
)
@click.pass_context
@handle_api_errors("Clusters")
+@auto_track_command("clusters retrieve")
def retrieve(ctx: click.Context, cluster_id: str, json: bool) -> None:
"""Retrieve a cluster by ID"""
client: Together = ctx.obj
@@ -273,6 +277,7 @@ def retrieve(ctx: click.Context, cluster_id: str, json: bool) -> None:
)
@click.pass_context
@handle_api_errors("Clusters")
+@auto_track_command("clusters update")
def update(
ctx: click.Context,
cluster_id: str,
@@ -308,6 +313,7 @@ def update(
)
@click.pass_context
@handle_api_errors("Clusters")
+@auto_track_command("clusters delete")
def delete(ctx: click.Context, cluster_id: str, json: bool) -> None:
"""Delete a cluster by ID"""
client: Together = ctx.obj
@@ -336,6 +342,7 @@ def delete(ctx: click.Context, cluster_id: str, json: bool) -> None:
)
@click.pass_context
@handle_api_errors("Clusters")
+@auto_track_command("clusters list-regions")
def list_regions(ctx: click.Context, json: bool) -> None:
"""List regions"""
client: Together = ctx.obj
diff --git a/src/together/lib/cli/api/beta/clusters_storage.py b/src/together/lib/cli/api/beta/clusters_storage.py
index 0a3214fa..2c25b593 100644
--- a/src/together/lib/cli/api/beta/clusters_storage.py
+++ b/src/together/lib/cli/api/beta/clusters_storage.py
@@ -7,6 +7,7 @@
from together import Together
from together.lib.cli.api.utils import handle_api_errors
+from together.lib.cli._track_cli import auto_track_command
from together.types.beta.clusters import ClusterStorage
@@ -56,6 +57,7 @@ def storage(ctx: click.Context) -> None:
)
@click.pass_context
@handle_api_errors("Clusters Storage")
+@auto_track_command("clusters storage create")
def create(ctx: click.Context, region: str, size_tib: int, volume_name: str, json: bool) -> None:
"""Create a storage volume"""
client: Together = ctx.obj
@@ -85,6 +87,7 @@ def create(ctx: click.Context, region: str, size_tib: int, volume_name: str, jso
)
@click.pass_context
@handle_api_errors("Clusters Storage")
+@auto_track_command("clusters storage retrieve")
def retrieve(ctx: click.Context, volume_id: str, json: bool) -> None:
"""Retrieve a storage volume"""
client: Together = ctx.obj
@@ -112,6 +115,7 @@ def retrieve(ctx: click.Context, volume_id: str, json: bool) -> None:
)
@click.pass_context
@handle_api_errors("Clusters Storage")
+@auto_track_command("clusters storage delete")
def delete(ctx: click.Context, volume_id: str, json: bool) -> None:
"""Delete a storage volume"""
client: Together = ctx.obj
@@ -140,6 +144,7 @@ def delete(ctx: click.Context, volume_id: str, json: bool) -> None:
)
@click.pass_context
@handle_api_errors("Clusters Storage")
+@auto_track_command("clusters storage list")
def list(ctx: click.Context, json: bool) -> None:
"""List storage volumes"""
client: Together = ctx.obj
diff --git a/src/together/lib/cli/api/endpoints.py b/src/together/lib/cli/api/endpoints.py
index 8383f629..205a5967 100644
--- a/src/together/lib/cli/api/endpoints.py
+++ b/src/together/lib/cli/api/endpoints.py
@@ -10,6 +10,7 @@
from together import Together, omit
from together.types import DedicatedEndpoint
from together._exceptions import APIError
+from together.lib.cli._track_cli import auto_track_command
from together.lib.utils.serializer import datetime_serializer
from together.types.endpoint_list_response import Data as DedicatedEndpointListItem
@@ -134,6 +135,7 @@ def endpoints(ctx: click.Context) -> None:
)
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints create")
def create(
client: Together,
model: str,
@@ -221,6 +223,7 @@ def create(
@click.option("--json", is_flag=True, help="Print output in JSON format")
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints get")
def get(client: Together, endpoint_id: str, json: bool) -> None:
"""Get a dedicated inference endpoint."""
endpoint = client.endpoints.retrieve(endpoint_id)
@@ -242,6 +245,7 @@ def get(client: Together, endpoint_id: str, json: bool) -> None:
)
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints hardware")
def hardware(client: Together, model: str | None, json: bool, available: bool) -> None:
"""List all available hardware options, optionally filtered by model."""
fetch_and_print_hardware_options(client, model, json, available)
@@ -274,6 +278,7 @@ def fetch_and_print_hardware_options(client: Together, model: str | None, print_
@click.option("--wait/--no-wait", default=True, help="Wait for the endpoint to stop")
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints stop")
def stop(client: Together, endpoint_id: str, wait: bool) -> None:
"""Stop a dedicated inference endpoint."""
client.endpoints.update(endpoint_id, state="STOPPED")
@@ -295,6 +300,7 @@ def stop(client: Together, endpoint_id: str, wait: bool) -> None:
@click.option("--wait/--no-wait", default=True, help="Wait for the endpoint to start")
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints start")
def start(client: Together, endpoint_id: str, wait: bool) -> None:
"""Start a dedicated inference endpoint."""
client.endpoints.update(endpoint_id, state="STARTED")
@@ -315,6 +321,7 @@ def start(client: Together, endpoint_id: str, wait: bool) -> None:
@click.argument("endpoint-id", required=True)
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints delete")
def delete(client: Together, endpoint_id: str) -> None:
"""Delete a dedicated inference endpoint."""
client.endpoints.delete(endpoint_id)
@@ -342,6 +349,7 @@ def delete(client: Together, endpoint_id: str) -> None:
)
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints list")
def list(
client: Together,
json: bool,
@@ -400,6 +408,7 @@ def list(
)
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints update")
def update(
client: Together,
endpoint_id: str,
@@ -449,6 +458,7 @@ def update(
@click.option("--json", is_flag=True, help="Print output in JSON format")
@click.pass_obj
@handle_api_errors
+@auto_track_command("endpoints availability-zones")
def availability_zones(client: Together, json: bool) -> None:
"""List all availability zones."""
avzones = client.endpoints.list_avzones()
diff --git a/src/together/lib/cli/api/evals.py b/src/together/lib/cli/api/evals.py
index 8eb4dba6..807dc0a5 100644
--- a/src/together/lib/cli/api/evals.py
+++ b/src/together/lib/cli/api/evals.py
@@ -8,6 +8,7 @@
from together import APIError, Together, TogetherError
from together._types import omit
+from together.lib.cli._track_cli import auto_track_command
from together.lib.utils.serializer import datetime_serializer
from together.types.eval_create_params import (
ParametersEvaluationScoreParameters,
@@ -274,6 +275,7 @@ def evals(ctx: click.Context) -> None:
help="Input template for model B.",
)
@handle_api_errors
+@auto_track_command("evals create")
def create(
ctx: click.Context,
type: Literal["classify", "score", "compare"],
@@ -489,6 +491,7 @@ def create(
type=int,
help="Limit number of results (max 100).",
)
+@auto_track_command("evals list")
def list(
ctx: click.Context,
status: Union[Literal["pending", "queued", "running", "completed", "error", "user_error"], None],
@@ -530,6 +533,7 @@ def list(
@evals.command()
@click.pass_context
@click.argument("evaluation_id", type=str, required=True)
+@auto_track_command("evals retrieve")
def retrieve(ctx: click.Context, evaluation_id: str) -> None:
"""Get details of a specific evaluation job"""
@@ -543,6 +547,7 @@ def retrieve(ctx: click.Context, evaluation_id: str) -> None:
@evals.command()
@click.pass_context
@click.argument("evaluation_id", type=str, required=True)
+@auto_track_command("evals status")
def status(ctx: click.Context, evaluation_id: str) -> None:
"""Get the status and results of a specific evaluation job"""
diff --git a/src/together/lib/cli/api/files.py b/src/together/lib/cli/api/files.py
index 9d716c71..d60ea691 100644
--- a/src/together/lib/cli/api/files.py
+++ b/src/together/lib/cli/api/files.py
@@ -8,6 +8,7 @@
from together import Together
from together.types import FilePurpose
+from together.lib.cli._track_cli import auto_track_command
# from together.utils import check_file, convert_bytes, convert_unix_timestamp
from ...utils import check_file, convert_bytes, convert_unix_timestamp
@@ -38,6 +39,7 @@ def files(ctx: click.Context) -> None:
default=True,
help="Whether to check the file before uploading.",
)
+@auto_track_command("files upload")
def upload(ctx: click.Context, file: pathlib.Path, purpose: FilePurpose, check: bool) -> None:
"""Upload file"""
@@ -50,6 +52,7 @@ def upload(ctx: click.Context, file: pathlib.Path, purpose: FilePurpose, check:
@files.command()
@click.pass_context
+@auto_track_command("files list")
def list(ctx: click.Context) -> None:
"""List files"""
client: Together = ctx.obj
@@ -75,6 +78,7 @@ def list(ctx: click.Context) -> None:
@files.command()
@click.pass_context
@click.argument("id", type=str, required=True)
+@auto_track_command("files retrieve")
def retrieve(ctx: click.Context, id: str) -> None:
"""Upload file"""
@@ -89,6 +93,7 @@ def retrieve(ctx: click.Context, id: str) -> None:
@click.pass_context
@click.argument("id", type=str, required=True)
@click.option("--output", type=str, default=None, help="Output filename")
+@auto_track_command("files retrieve-content")
def retrieve_content(ctx: click.Context, id: str, output: str) -> None:
"""Retrieve file content and output to file"""
@@ -108,6 +113,7 @@ def retrieve_content(ctx: click.Context, id: str, output: str) -> None:
@files.command()
@click.pass_context
@click.argument("id", type=str, required=True)
+@auto_track_command("files delete")
def delete(ctx: click.Context, id: str) -> None:
"""Delete remote file"""
@@ -125,6 +131,7 @@ def delete(ctx: click.Context, id: str) -> None:
type=click.Path(exists=True, file_okay=True, resolve_path=True, readable=True, dir_okay=False),
required=True,
)
+@auto_track_command("files check")
def check(_ctx: click.Context, file: pathlib.Path) -> None:
"""Check file for issues"""
diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py
index 549fe569..3efc8087 100644
--- a/src/together/lib/cli/api/fine_tuning.py
+++ b/src/together/lib/cli/api/fine_tuning.py
@@ -19,6 +19,7 @@
from together.lib.utils import log_warn
from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO, generate_progress_bar
+from together.lib.cli._track_cli import auto_track_command
from together.lib.resources.files import DownloadManager
from together.lib.utils.serializer import datetime_serializer
from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType
@@ -216,6 +217,7 @@ def fine_tuning(ctx: click.Context) -> None:
default=None,
help="HF repo to upload the fine-tuned model to",
)
+@auto_track_command("fine-tuning create")
def create(
ctx: click.Context,
training_file: str,
@@ -415,6 +417,7 @@ def create(
@fine_tuning.command()
@click.pass_context
+@auto_track_command("fine-tuning list")
def list(ctx: click.Context) -> None:
"""List fine-tuning jobs"""
client: Together = ctx.obj
@@ -449,6 +452,7 @@ def list(ctx: click.Context) -> None:
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
+@auto_track_command("fine-tuning retrieve")
def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
"""Retrieve fine-tuning job details"""
client: Together = ctx.obj
@@ -468,6 +472,7 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
@click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before cancelling job")
+@auto_track_command("fine-tuning cancel")
def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None:
"""Cancel fine-tuning job"""
client: Together = ctx.obj
@@ -487,6 +492,7 @@ def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None:
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
+@auto_track_command("fine-tuning list-events")
def list_events(ctx: click.Context, fine_tune_id: str) -> None:
"""List fine-tuning events"""
client: Together = ctx.obj
@@ -513,6 +519,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
+@auto_track_command("fine-tuning list-checkpoints")
def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
"""List available checkpoints for a fine-tuning job"""
client: Together = ctx.obj
@@ -569,6 +576,7 @@ def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
default="merged",
help="Specifies checkpoint type. 'merged' and 'adapter' options work only for LoRA jobs.",
)
+@auto_track_command("fine-tuning download")
def download(
ctx: click.Context,
fine_tune_id: str,
@@ -628,6 +636,7 @@ def download(
@click.argument("fine_tune_id", type=str, required=True)
@click.option("--force", is_flag=True, help="Force deletion without confirmation")
@click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before deleting job")
+@auto_track_command("fine-tuning delete")
def delete(ctx: click.Context, fine_tune_id: str, force: bool = False, quiet: bool = False) -> None:
"""Delete fine-tuning job"""
client: Together = ctx.obj
diff --git a/src/together/lib/cli/api/models.py b/src/together/lib/cli/api/models.py
index cb566695..737a4d31 100644
--- a/src/together/lib/cli/api/models.py
+++ b/src/together/lib/cli/api/models.py
@@ -7,6 +7,7 @@
from together import Together, omit
from together._models import BaseModel
from together._response import APIResponse as APIResponse
+from together.lib.cli._track_cli import auto_track_command
from together.types.model_upload_response import ModelUploadResponse
@@ -29,12 +30,16 @@ def models(ctx: click.Context) -> None:
help="Output in JSON format",
)
@click.pass_context
+@auto_track_command("models list")
def list(ctx: click.Context, type: Optional[str], json: bool) -> None:
"""List models"""
client: Together = ctx.obj
models_list = client.models.list(dedicated=type == "dedicated" if type else omit)
+ if json:
+ click.echo(json_lib.dumps(models_list, indent=2))
+
display_list: List[Dict[str, Any]] = []
model: BaseModel
for model in models_list:
@@ -51,10 +56,7 @@ def list(ctx: click.Context, type: Optional[str], json: bool) -> None:
}
)
- if json:
- click.echo(json_lib.dumps(display_list, indent=2))
- else:
- click.echo(tabulate(display_list, headers="keys", tablefmt="plain"))
+ click.echo(tabulate(display_list, headers="keys", tablefmt="plain"))
@models.command()
@@ -96,6 +98,7 @@ def list(ctx: click.Context, type: Optional[str], json: bool) -> None:
help="Output in JSON format",
)
@click.pass_context
+@auto_track_command("models upload")
def upload(
ctx: click.Context,
model_name: str,
diff --git a/src/together/lib/cli/cli.py b/src/together/lib/cli/cli.py
index c3f01d49..f549b042 100644
--- a/src/together/lib/cli/cli.py
+++ b/src/together/lib/cli/cli.py
@@ -4,12 +4,14 @@
from typing import Any
import click
+import httpx
import together
from together._version import __version__
from together._constants import DEFAULT_TIMEOUT
from together.lib.cli.api.evals import evals
from together.lib.cli.api.files import files
+from together.lib.cli._track_cli import CliTrackingEvents, track_cli
from together.lib.cli.api.models import models
from together.lib.cli.api.beta.beta import beta
from together.lib.cli.api.endpoints import endpoints
@@ -57,10 +59,22 @@ def main(
) -> None:
"""This is a sample CLI tool."""
os.environ.setdefault("TOGETHER_LOG", "debug" if debug else "info")
- ctx.obj = together.Together(
+
+ client = together.Together(
api_key=api_key, base_url=base_url, timeout=timeout, max_retries=max_retries if max_retries is not None else 0
)
+ # Wrap the client's httpx requests to track the parameters sent on api requests
+ def track_request(request: httpx.Request) -> None:
+ track_cli(
+ CliTrackingEvents.ApiRequest,
+ {"url": str(request.url), "method": request.method, "body": request.content.decode("utf-8")},
+ )
+
+ client._client.event_hooks["request"].append(track_request)
+
+ ctx.obj = client
+
main.add_command(files)
main.add_command(fine_tuning)
diff --git a/src/together/resources/beta/clusters/clusters.py b/src/together/resources/beta/clusters/clusters.py
index d18f215c..c3e89a8c 100644
--- a/src/together/resources/beta/clusters/clusters.py
+++ b/src/together/resources/beta/clusters/clusters.py
@@ -80,8 +80,13 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterCreateResponse:
- """
- Create GPU Cluster
+ """Create an Instant Cluster on Together's high-performance GPU clusters.
+
+ With
+ features like on-demand scaling, long-lived resizable high-bandwidth shared
+ DC-local storage, Kubernetes and Slurm cluster flavors, a REST API, and
+ Terraform support, you can run workloads flexibly without complex infrastructure
+ management.
Args:
cluster_name: Name of the GPU cluster.
@@ -141,7 +146,7 @@ def retrieve(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Cluster:
"""
- Get GPU cluster by cluster ID
+ Retrieve information about a specific GPU cluster.
Args:
extra_headers: Send extra headers
@@ -176,7 +181,7 @@ def update(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterUpdateResponse:
"""
- Update a GPU Cluster.
+ Update the configuration of an existing GPU cluster.
Args:
extra_headers: Send extra headers
@@ -235,7 +240,7 @@ def delete(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterDeleteResponse:
"""
- Delete GPU cluster by cluster ID
+ Delete a GPU cluster by cluster ID.
Args:
extra_headers: Send extra headers
@@ -320,8 +325,13 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterCreateResponse:
- """
- Create GPU Cluster
+ """Create an Instant Cluster on Together's high-performance GPU clusters.
+
+ With
+ features like on-demand scaling, long-lived resizable high-bandwidth shared
+ DC-local storage, Kubernetes and Slurm cluster flavors, a REST API, and
+ Terraform support, you can run workloads flexibly without complex infrastructure
+ management.
Args:
cluster_name: Name of the GPU cluster.
@@ -381,7 +391,7 @@ async def retrieve(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> Cluster:
"""
- Get GPU cluster by cluster ID
+ Retrieve information about a specific GPU cluster.
Args:
extra_headers: Send extra headers
@@ -416,7 +426,7 @@ async def update(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterUpdateResponse:
"""
- Update a GPU Cluster.
+ Update the configuration of an existing GPU cluster.
Args:
extra_headers: Send extra headers
@@ -475,7 +485,7 @@ async def delete(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterDeleteResponse:
"""
- Delete GPU cluster by cluster ID
+ Delete a GPU cluster by cluster ID.
Args:
extra_headers: Send extra headers
diff --git a/src/together/resources/beta/clusters/storage.py b/src/together/resources/beta/clusters/storage.py
index c7915737..508f9471 100644
--- a/src/together/resources/beta/clusters/storage.py
+++ b/src/together/resources/beta/clusters/storage.py
@@ -57,12 +57,15 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> StorageCreateResponse:
- """Create a shared volume.
+ """
+ Instant Clusters supports long-lived, resizable in-DC shared storage with user
+ data persistence. You can dynamically create and attach volumes to your cluster
+ at cluster creation time, and resize as your data grows. All shared storage is
+ backed by multi-NIC bare metal paths, ensuring high-throughput and low-latency
+ performance for shared storage.
Args:
- region: Region name.
-
- Usable regions can be found from `client.clusters.list_regions()`
+ region: Region name. Usable regions can be found from `client.clusters.list_regions()`
size_tib: Volume size in whole tebibytes (TiB).
@@ -102,7 +105,7 @@ def retrieve(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterStorage:
"""
- Get shared volume by volume Id.
+ Retrieve information about a specific shared volume.
Args:
extra_headers: Send extra headers
@@ -136,7 +139,7 @@ def update(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterStorage:
"""
- Update a shared volume.
+ Update the configuration of an existing shared volume.
Args:
extra_headers: Send extra headers
@@ -192,8 +195,10 @@ def delete(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> StorageDeleteResponse:
- """
- Delete shared volume by volume id.
+ """Delete a shared volume.
+
+ Note that if this volume is attached to a cluster,
+ deleting will fail.
Args:
extra_headers: Send extra headers
@@ -248,12 +253,15 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> StorageCreateResponse:
- """Create a shared volume.
+ """
+ Instant Clusters supports long-lived, resizable in-DC shared storage with user
+ data persistence. You can dynamically create and attach volumes to your cluster
+ at cluster creation time, and resize as your data grows. All shared storage is
+ backed by multi-NIC bare metal paths, ensuring high-throughput and low-latency
+ performance for shared storage.
Args:
- region: Region name.
-
- Usable regions can be found from `client.clusters.list_regions()`
+ region: Region name. Usable regions can be found from `client.clusters.list_regions()`
size_tib: Volume size in whole tebibytes (TiB).
@@ -293,7 +301,7 @@ async def retrieve(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterStorage:
"""
- Get shared volume by volume Id.
+ Retrieve information about a specific shared volume.
Args:
extra_headers: Send extra headers
@@ -327,7 +335,7 @@ async def update(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> ClusterStorage:
"""
- Update a shared volume.
+ Update the configuration of an existing shared volume.
Args:
extra_headers: Send extra headers
@@ -383,8 +391,10 @@ async def delete(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> StorageDeleteResponse:
- """
- Delete shared volume by volume id.
+ """Delete a shared volume.
+
+ Note that if this volume is attached to a cluster,
+ deleting will fail.
Args:
extra_headers: Send extra headers
diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py
index 0aa7ff03..645b89dd 100644
--- a/src/together/types/__init__.py
+++ b/src/together/types/__init__.py
@@ -1,6 +1,6 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-from __future__ import annotations # noqa
+from __future__ import annotations # noqa
from .batch_job import BatchJob as BatchJob
from .embedding import Embedding as Embedding
@@ -70,12 +70,13 @@
)
# Manually added to minimize breaking changes from V1
-from .chat.chat_completion import ChatCompletion
+from .chat.chat_completion import ChatCompletion
from .chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk
from .chat.chat_completion_usage import ChatCompletionUsage
+
UsageData = ChatCompletionUsage
ChatCompletionResponse = ChatCompletion
CompletionResponse = Completion
ListEndpoint = EndpointListResponse
ImageRequest = ImageGenerateParams
-ImageResponse = ImageFile
\ No newline at end of file
+ImageResponse = ImageFile
diff --git a/src/together/types/chat_completions.py b/src/together/types/chat_completions.py
index 956a6fda..dfa7f175 100644
--- a/src/together/types/chat_completions.py
+++ b/src/together/types/chat_completions.py
@@ -4,4 +4,4 @@
from .chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk
ChatCompletionResponse = ChatCompletion
-ToolCalls = ToolChoice
\ No newline at end of file
+ToolCalls = ToolChoice
diff --git a/src/together/types/endpoints.py b/src/together/types/endpoints.py
index b2e49e6d..808df6ca 100644
--- a/src/together/types/endpoints.py
+++ b/src/together/types/endpoints.py
@@ -1,4 +1,4 @@
# Manually added to minimize breaking changes from V1
from together.types import DedicatedEndpoint as DedicatedEndpoint
-ListEndpoint = DedicatedEndpoint
\ No newline at end of file
+ListEndpoint = DedicatedEndpoint
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index 8d371cac..a3322f1d 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -32,7 +32,7 @@ def create_finetune_response(
return FinetuneResponse(
id=job_id,
progress=progress,
- updated_at=started_at, # to calm down mypy
+ updated_at=started_at, # to calm down mypy
started_at=started_at,
status=status,
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
diff --git a/tests/test_client.py b/tests/test_client.py
index 579f2f6b..011fef4b 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -8,10 +8,11 @@
import json
import asyncio
import inspect
+import dataclasses
import tracemalloc
-from typing import Any, Union, cast
+from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast
from unittest import mock
-from typing_extensions import Literal
+from typing_extensions import Literal, AsyncIterator, override
import httpx
import pytest
@@ -37,6 +38,7 @@
from .utils import update_env
+T = TypeVar("T")
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"
@@ -51,6 +53,57 @@ def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
return 0.1
+def mirror_request_content(request: httpx.Request) -> httpx.Response:
+ return httpx.Response(200, content=request.content)
+
+
+# note: we can't use the httpx.MockTransport class as it consumes the request
+# body itself, which means we can't test that the body is read lazily
+class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
+ def __init__(
+ self,
+ handler: Callable[[httpx.Request], httpx.Response]
+ | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
+ ) -> None:
+ self.handler = handler
+
+ @override
+ def handle_request(
+ self,
+ request: httpx.Request,
+ ) -> httpx.Response:
+ assert not inspect.iscoroutinefunction(self.handler), "handler must not be a coroutine function"
+ assert inspect.isfunction(self.handler), "handler must be a function"
+ return self.handler(request)
+
+ @override
+ async def handle_async_request(
+ self,
+ request: httpx.Request,
+ ) -> httpx.Response:
+ assert inspect.iscoroutinefunction(self.handler), "handler must be a coroutine function"
+ return await self.handler(request)
+
+
+@dataclasses.dataclass
+class Counter:
+ value: int = 0
+
+
+def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
+ for item in iterable:
+ if counter:
+ counter.value += 1
+ yield item
+
+
+async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
+ for item in iterable:
+ if counter:
+ counter.value += 1
+ yield item
+
+
def _get_open_connections(client: Together | AsyncTogether) -> int:
transport = client._client._transport
assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport)
@@ -503,6 +556,70 @@ def test_multipart_repeating_array(self, client: Together) -> None:
b"",
]
+ @pytest.mark.respx(base_url=base_url)
+ def test_binary_content_upload(self, respx_mock: MockRouter, client: Together) -> None:
+ respx_mock.post("/upload").mock(side_effect=mirror_request_content)
+
+ file_content = b"Hello, this is a test file."
+
+ response = client.post(
+ "/upload",
+ content=file_content,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+
+ def test_binary_content_upload_with_iterator(self) -> None:
+ file_content = b"Hello, this is a test file."
+ counter = Counter()
+ iterator = _make_sync_iterator([file_content], counter=counter)
+
+ def mock_handler(request: httpx.Request) -> httpx.Response:
+ assert counter.value == 0, "the request body should not have been read"
+ return httpx.Response(200, content=request.read())
+
+ with Together(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.Client(transport=MockTransport(handler=mock_handler)),
+ ) as client:
+ response = client.post(
+ "/upload",
+ content=iterator,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+ assert counter.value == 1
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: Together) -> None:
+ respx_mock.post("/upload").mock(side_effect=mirror_request_content)
+
+ file_content = b"Hello, this is a test file."
+
+ with pytest.deprecated_call(
+ match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
+ ):
+ response = client.post(
+ "/upload",
+ body=file_content,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+
@pytest.mark.respx(base_url=base_url)
def test_basic_union_response(self, respx_mock: MockRouter, client: Together) -> None:
class Model1(BaseModel):
@@ -1375,6 +1492,72 @@ def test_multipart_repeating_array(self, async_client: AsyncTogether) -> None:
b"",
]
+ @pytest.mark.respx(base_url=base_url)
+ async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncTogether) -> None:
+ respx_mock.post("/upload").mock(side_effect=mirror_request_content)
+
+ file_content = b"Hello, this is a test file."
+
+ response = await async_client.post(
+ "/upload",
+ content=file_content,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+
+ async def test_binary_content_upload_with_asynciterator(self) -> None:
+ file_content = b"Hello, this is a test file."
+ counter = Counter()
+ iterator = _make_async_iterator([file_content], counter=counter)
+
+ async def mock_handler(request: httpx.Request) -> httpx.Response:
+ assert counter.value == 0, "the request body should not have been read"
+ return httpx.Response(200, content=await request.aread())
+
+ async with AsyncTogether(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)),
+ ) as client:
+ response = await client.post(
+ "/upload",
+ content=iterator,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+ assert counter.value == 1
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_binary_content_upload_with_body_is_deprecated(
+ self, respx_mock: MockRouter, async_client: AsyncTogether
+ ) -> None:
+ respx_mock.post("/upload").mock(side_effect=mirror_request_content)
+
+ file_content = b"Hello, this is a test file."
+
+ with pytest.deprecated_call(
+ match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
+ ):
+ response = await async_client.post(
+ "/upload",
+ body=file_content,
+ cast_to=httpx.Response,
+ options={"headers": {"Content-Type": "application/octet-stream"}},
+ )
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == "application/octet-stream"
+ assert response.content == file_content
+
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncTogether) -> None:
class Model1(BaseModel):
diff --git a/uv.lock b/uv.lock
index ffc03b56..6c3569f8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1213,6 +1213,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" },
]
+[[package]]
+name = "py-machineid"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "winregistry", marker = "sys_platform == 'win32' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f4/b0/c7fa6de7298a8f4e544929b97fa028304c0e11a4bc9500eff8689821bdbb/py_machineid-1.0.0.tar.gz", hash = "sha256:8a902a00fae8c6d6433f463697c21dc4ce98c6e55a2e0535c0273319acb0047a", size = 4629, upload-time = "2025-12-02T16:12:54.286Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/be/76/1ed8375cb1212824c57eb706e1f09f3f2ca4ed12b8d56b28a160e2d53505/py_machineid-1.0.0-py3-none-any.whl", hash = "sha256:910df0d5f2663bcf6739d835c4949f4e9cc6bb090a58b3dd766e12e5f768e3b9", size = 4926, upload-time = "2025-12-02T16:12:20.584Z" },
+]
+
[[package]]
name = "pyarrow"
version = "21.0.0"
@@ -1963,7 +1975,7 @@ wheels = [
[[package]]
name = "together"
-version = "2.0.0a14"
+version = "2.0.0a15"
source = { editable = "." }
dependencies = [
{ name = "anyio" },
@@ -1975,6 +1987,7 @@ dependencies = [
{ name = "httpx" },
{ name = "pillow", version = "11.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" },
{ name = "pillow", version = "12.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' or (extra == 'group-8-together-pydantic-v1' and extra == 'group-8-together-pydantic-v2')" },
+ { name = "py-machineid" },
{ name = "pydantic", version = "1.10.24", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-8-together-pydantic-v1'" },
{ name = "pydantic", version = "2.12.5", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-8-together-pydantic-v2' or extra != 'group-8-together-pydantic-v1'" },
{ name = "rich" },
@@ -2033,6 +2046,7 @@ requires-dist = [
{ name = "httpx", specifier = ">=0.23.0,<1" },
{ name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.9" },
{ name = "pillow", specifier = ">=10.4.0" },
+ { name = "py-machineid", specifier = ">=1.0.0" },
{ name = "pyarrow", marker = "extra == 'pyarrow'", specifier = ">=16.1.0" },
{ name = "pyarrow-stubs", marker = "extra == 'pyarrow'", specifier = ">=10.0.1.7" },
{ name = "pydantic", specifier = ">=1.9.0,<3" },
@@ -2192,6 +2206,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
]
+[[package]]
+name = "winregistry"
+version = "2.1.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/86/94/ddc339d2562267af7d25d5067874f7df8c6c19ab9dd976fa830982b1c398/winregistry-2.1.2.tar.gz", hash = "sha256:50260e1aaba4116f707f86a4e287ffcb1eeae7dc0a0883c6d1776017e693fc69", size = 9538, upload-time = "2025-10-09T09:25:07.391Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/52/dd/5a18d9fbf9a3d69b40e395d80779dfaeda77b98c946df36bf7df41ddcaa5/winregistry-2.1.2-py3-none-any.whl", hash = "sha256:e142548f56fc1fc6b83ddf88baca2e9e18cd6a266d9e00f111e54977dee768cf", size = 8507, upload-time = "2025-10-09T09:25:05.82Z" },
+]
+
[[package]]
name = "yarl"
version = "1.22.0"