Skip to content

Commit 8c2d301

Browse files
Merge pull request #105 from forcedotcom/callout
support llm gateway callout for script
2 parents 065878d + 3865b60 commit 8c2d301

21 files changed

Lines changed: 921 additions & 85 deletions

File tree

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,55 @@ Options:
305305
- `--function-invoke-opt TEXT`: Currently we support only `UnstructuredChunking` for functions.
306306

307307

308+
## Testing LLM Gateway
309+
310+
You can use AI models configured in Salesforce to generate responses while transforming your data. Below is a sample script:
311+
312+
```
313+
from pyspark.sql.functions import col
from datacustomcode.client import Client, llm_gateway_generate_text_col
from datacustomcode.io.writer.base import WriteMode
314+
315+
316+
def main():
317+
client = Client()
318+
df = client.read_dlo("Input__dll")
319+
# llm_gateway_generate_text_col returns a struct
320+
# {status, response, error_code, error_message} per row, so per-row
321+
# failures don't abort the Spark job. Pick the field you want with [].
322+
df_generated = df.withColumn(
323+
"greeting__c",
324+
llm_gateway_generate_text_col(
325+
"In one sentence, greet {name} from {city}.",
326+
{"name": col("name__c"), "city": col("homecity__c")},
327+
model_id="sfdc_ai__DefaultGPT4Omni", # An AI model in your org
328+
)["response"],
329+
)
330+
331+
dlo_name = "Output__dll"
332+
client.write_to_dlo(dlo_name, df_generated, write_mode=WriteMode.APPEND)
333+
334+
greeting = client.llm_gateway_generate_text("In one sentence, generate a greeting message", "sfdc_ai__DefaultGPT52")
335+
336+
if __name__ == "__main__":
337+
main()
338+
```
339+
340+
To test this code on your local machine before deploying it to Data Cloud, you must first set up an External Client App (ECA) that allows access to the Agent API. Follow [this guide](https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app) to create the ECA. You must use `http://localhost:1717/OauthRedirect` as the callback URL.
341+
342+
Once the ECA is set up, use it to log in to your org:
343+
```
344+
sf org login web \
345+
--alias myorg \
346+
--instance-url https://{MY_DOMAIN_URL} \
347+
--client-id {CONSUMER_KEY} \
348+
--scopes "sfap_api api"
349+
```
350+
351+
Then you can test your code using the `myorg` alias:
352+
```
353+
datacustomcode run ./payload/entrypoint.py --sf-cli-org myorg
354+
```
355+
356+
308357
## Docker usage
309358

310359
The SDK provides Docker-based development options that allow you to test your code in an environment that closely resembles Data Cloud's execution environment.

src/datacustomcode/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
"AuthType",
1818
"Client",
1919
"Credentials",
20+
"DefaultSparkLLMGateway",
2021
"PrintDataCloudWriter",
2122
"QueryAPIDataCloudReader",
23+
"SparkLLMGateway",
24+
"llm_gateway_generate_text_col",
2225
]
2326

2427

@@ -44,4 +47,16 @@ def __getattr__(name: str):
4447
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
4548

4649
return QueryAPIDataCloudReader
50+
elif name == "SparkLLMGateway":
51+
from datacustomcode.llm_gateway import SparkLLMGateway
52+
53+
return SparkLLMGateway
54+
elif name == "DefaultSparkLLMGateway":
55+
from datacustomcode.llm_gateway import DefaultSparkLLMGateway
56+
57+
return DefaultSparkLLMGateway
58+
elif name == "llm_gateway_generate_text_col":
59+
from datacustomcode.client import llm_gateway_generate_text_col
60+
61+
return llm_gateway_generate_text_col
4762
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

src/datacustomcode/client.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,87 @@
1818
from typing import (
1919
TYPE_CHECKING,
2020
ClassVar,
21+
Dict,
2122
Optional,
23+
Union,
2224
)
2325

2426
from datacustomcode.config import config
2527
from datacustomcode.file.path.default import DefaultFindFilePath
2628
from datacustomcode.io.reader.base import BaseDataCloudReader
29+
from datacustomcode.llm_gateway_config import spark_llm_gateway_config
2730
from datacustomcode.spark.default import DefaultSparkSessionProvider
2831

2932
if TYPE_CHECKING:
3033
from pathlib import Path
3134

32-
from pyspark.sql import DataFrame as PySparkDataFrame
35+
from pyspark.sql import Column, DataFrame as PySparkDataFrame
3336

3437
from datacustomcode.io.reader.base import BaseDataCloudReader
3538
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
39+
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
3640
from datacustomcode.spark.base import BaseSparkSessionProvider
3741

3842

43+
def _build_spark_llm_gateway() -> "SparkLLMGateway":
44+
"""Instantiate the SDK-configured :class:`SparkLLMGateway`.
45+
46+
Raises:
47+
RuntimeError: If no ``spark_llm_gateway_config`` has been loaded.
48+
"""
49+
cfg = spark_llm_gateway_config.spark_llm_gateway_config
50+
if cfg is None:
51+
raise RuntimeError(
52+
"spark_llm_gateway_config is not configured. Add a "
53+
"'spark_llm_gateway_config' section to config.yaml."
54+
)
55+
return cfg.to_object()
56+
57+
58+
def llm_gateway_generate_text_col(
59+
template: str,
60+
values: Union[Dict[str, "Column"], "Column"],
61+
model_id: Optional[str] = None,
62+
) -> "Column":
63+
"""Build a Spark Column that runs the LLM Gateway per row.
64+
65+
The returned Column yields a struct ``{status, response, error_code,
66+
error_message}`` for each row. Use ``[...]`` (or ``getField``) to pick the
67+
field you want, e.g. ``llm_gateway_generate_text_col(...)["response"]``.
68+
Per-row failures populate ``status`` / ``error_code`` / ``error_message``
69+
so a single bad row does not abort the whole Spark job.
70+
71+
Example:
72+
73+
>>> result = llm_gateway_generate_text_col(
74+
... "In one sentence, greet {name} from {city}.",
75+
... {"name": col("name__c"), "city": col("homecity__c")},
76+
... model_id="sfdc_ai__DefaultGPT4Omni",
77+
... )
78+
>>> df.withColumn("greeting__c", result["response"])
79+
>>> # …or keep the struct around and inspect failures:
80+
>>> df.withColumn("llm", result).select(
81+
... "llm.status", "llm.response", "llm.error_message"
82+
... )
83+
84+
Args:
85+
template: The prompt template, with ``{field}`` placeholders matching
86+
keys in ``values``. Substitution uses ``str.format``.
87+
values: Either a mapping from placeholder name to Spark ``Column``, or
88+
a single ``Column`` whose value is already a struct.
89+
model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``.
90+
91+
Returns:
92+
A Spark ``Column`` of ``StructType`` with fields ``status``,
93+
``response``, ``error_code``, and ``error_message`` (all nullable
94+
strings). On success, ``status == "SUCCESS"`` and ``response`` holds
95+
the generated text; on failure, ``status == "ERROR"`` and the
96+
``error_*`` fields carry diagnostic detail.
97+
"""
98+
gateway = Client()._get_spark_llm_gateway()
99+
return gateway.llm_gateway_generate_text_col(template, values, model_id=model_id)
100+
101+
39102
class DataCloudObjectType(Enum):
40103
DLO = "dlo"
41104
DMO = "dmo"
@@ -94,18 +157,21 @@ class Client:
94157
finder: Find a file path
95158
reader: A custom reader to use for reading Data Cloud objects.
96159
writer: A custom writer to use for writing Data Cloud objects.
160+
spark_llm_gateway: Optional custom :class:`SparkLLMGateway`.
97161
98162
Example:
99163
>>> client = Client()
100164
>>> file_path = client.find_file_path("data.csv")
101165
>>> dlo = client.read_dlo("my_dlo")
102166
>>> client.write_to_dmo("my_dmo", dlo)
167+
>>> answer = client.llm_gateway_generate_text("Generate a greeting message")
103168
"""
104169

105170
_instance: ClassVar[Optional[Client]] = None
106171
_reader: BaseDataCloudReader
107172
_writer: BaseDataCloudWriter
108173
_file: DefaultFindFilePath
174+
_spark_llm_gateway: Optional[SparkLLMGateway]
109175
_data_layer_history: dict[DataCloudObjectType, set[str]]
110176
_code_type: str
111177

@@ -114,11 +180,13 @@ def __new__(
114180
reader: Optional[BaseDataCloudReader] = None,
115181
writer: Optional[BaseDataCloudWriter] = None,
116182
spark_provider: Optional[BaseSparkSessionProvider] = None,
183+
spark_llm_gateway: Optional[SparkLLMGateway] = None,
117184
code_type: str = "script",
118185
) -> Client:
119186

120187
if cls._instance is None:
121188
cls._instance = super().__new__(cls)
189+
cls._instance._spark_llm_gateway = spark_llm_gateway
122190
# Initialize Readers and Writers from config
123191
# and/or provided reader and writer
124192
if reader is None or writer is None:
@@ -225,6 +293,41 @@ def find_file_path(self, file_name: str) -> Path:
225293

226294
return self._file.find_file_path(file_name) # type: ignore[no-any-return]
227295

296+
def llm_gateway_generate_text(
297+
self,
298+
prompt: str,
299+
model_id: Optional[str] = None,
300+
) -> str:
301+
"""Issue a one-shot LLM Gateway call. This is the scalar counterpart to
302+
:func:`llm_gateway_generate_text_col`: it runs **once** — not per row.
303+
Use the column helper method instead when you want to fan a prompt out across
304+
every row of a DataFrame.
305+
306+
Example:
307+
308+
>>> response = Client().llm_gateway_generate_text(
309+
... "Generate a greeting message"
310+
... )
311+
312+
Args:
313+
prompt: The literal prompt to send. Plain text — no
314+
``{field}`` substitution is performed on this string.
315+
model_id: LLM model id to target. Defaults to
316+
``sfdc_ai__DefaultGPT4Omni`` when ``None``.
317+
318+
Returns:
319+
The generated text as a plain Python ``str``; empty when the
320+
gateway response carries no generated text.
321+
"""
322+
return self._get_spark_llm_gateway().llm_gateway_generate_text(
323+
prompt, model_id=model_id
324+
)
325+
326+
def _get_spark_llm_gateway(self) -> SparkLLMGateway:
327+
if self._spark_llm_gateway is None:
328+
self._spark_llm_gateway = _build_spark_llm_gateway()
329+
return self._spark_llm_gateway
330+
228331
def _validate_data_layer_history_does_not_contain(
229332
self, data_cloud_object_type: DataCloudObjectType
230333
) -> None:

src/datacustomcode/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ llm_gateway_config:
2828
type_config_name: DefaultLLMGateway
2929
options:
3030
credentials_profile: default
31+
32+
spark_llm_gateway_config:
33+
type_config_name: DefaultSparkLLMGateway

src/datacustomcode/einstein_platform_config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@
1515

1616
from typing import (
1717
ClassVar,
18+
Generic,
1819
Optional,
1920
Type,
20-
cast,
21+
TypeVar,
2122
)
2223

2324
from datacustomcode.common_config import BaseObjectConfig
2425

26+
_T = TypeVar("_T")
2527

26-
class CredentialsObjectConfig(BaseObjectConfig):
28+
29+
class CredentialsObjectConfig(BaseObjectConfig, Generic[_T]):
2730
type_to_create: ClassVar[Type]
2831
credentials_profile: Optional[str] = None
2932
sf_cli_org: Optional[str] = None
3033

31-
def to_object(self):
34+
def to_object(self) -> _T:
3235
"""Create an object instance, automatically including credentials in options"""
3336

3437
options = self.options.copy()
@@ -38,4 +41,5 @@ def to_object(self):
3841
options["sf_cli_org"] = self.sf_cli_org
3942

4043
type_ = self.type_to_create.subclass_from_config_name(self.type_config_name)
41-
return cast(type_, type_(**options))
44+
instance: _T = type_(**options)
45+
return instance

src/datacustomcode/einstein_predictions_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
_E = TypeVar("_E", bound=EinsteinPredictions)
2929

3030

31-
class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]):
31+
class EinsteinPredictionsObjectConfig(CredentialsObjectConfig[_E], Generic[_E]):
3232
type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract]
3333

3434

src/datacustomcode/function/feature_types/chunking.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
Pydantic models for Search Index Chunking V1
1818
"""
19+
1920
from enum import Enum
2021
from typing import (
2122
Dict,

src/datacustomcode/llm_gateway/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515

1616
from datacustomcode.llm_gateway.base import LLMGateway
1717
from datacustomcode.llm_gateway.default import DefaultLLMGateway
18+
from datacustomcode.llm_gateway.errors import LLMGatewayCallError
19+
from datacustomcode.llm_gateway.spark_base import SparkLLMGateway
20+
from datacustomcode.llm_gateway.spark_default import DefaultSparkLLMGateway
1821

1922
__all__ = [
2023
"DefaultLLMGateway",
24+
"DefaultSparkLLMGateway",
2125
"LLMGateway",
26+
"LLMGatewayCallError",
27+
"SparkLLMGateway",
2228
]
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,25 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
"""Exceptions raised by LLM Gateway implementations."""
16+
1517
from __future__ import annotations
1618

17-
from abc import ABC
19+
from typing import Optional
1820

19-
from datacustomcode.mixin import UserExtendableNamedConfigMixin
2021

22+
class LLMGatewayCallError(RuntimeError):
23+
"""Raised when an LLM Gateway call returns an error."""
2124

22-
class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin):
23-
def __init__(self):
24-
pass
25+
def __init__(
26+
self,
27+
message: str,
28+
*,
29+
status: Optional[object] = None,
30+
error_code: Optional[str] = None,
31+
error_message: Optional[str] = None,
32+
) -> None:
33+
super().__init__(message)
34+
self.status = status
35+
self.error_code = error_code
36+
self.error_message = error_message

0 commit comments

Comments
 (0)