1818from typing import (
1919 TYPE_CHECKING ,
2020 ClassVar ,
21+ Dict ,
2122 Optional ,
23+ Union ,
2224)
2325
2426from datacustomcode .config import config
2527from datacustomcode .file .path .default import DefaultFindFilePath
2628from datacustomcode .io .reader .base import BaseDataCloudReader
29+ from datacustomcode .llm_gateway_config import spark_llm_gateway_config
2730from datacustomcode .spark .default import DefaultSparkSessionProvider
2831
2932if TYPE_CHECKING :
3033 from pathlib import Path
3134
32- from pyspark .sql import DataFrame as PySparkDataFrame
35+ from pyspark .sql import Column , DataFrame as PySparkDataFrame
3336
3437 from datacustomcode .io .reader .base import BaseDataCloudReader
3538 from datacustomcode .io .writer .base import BaseDataCloudWriter , WriteMode
39+ from datacustomcode .llm_gateway .spark_base import SparkLLMGateway
3640 from datacustomcode .spark .base import BaseSparkSessionProvider
3741
3842
43+ def _build_spark_llm_gateway () -> "SparkLLMGateway" :
44+ """Instantiate the SDK-configured :class:`SparkLLMGateway`.
45+
46+ Raises:
47+ RuntimeError: If no ``spark_llm_gateway_config`` has been loaded.
48+ """
49+ cfg = spark_llm_gateway_config .spark_llm_gateway_config
50+ if cfg is None :
51+ raise RuntimeError (
52+ "spark_llm_gateway_config is not configured. Add a "
53+ "'spark_llm_gateway_config' section to config.yaml."
54+ )
55+ return cfg .to_object ()
56+
57+
58+ def llm_gateway_generate_text_col (
59+ template : str ,
60+ values : Union [Dict [str , "Column" ], "Column" ],
61+ model_id : Optional [str ] = None ,
62+ ) -> "Column" :
63+ """Build a Spark Column that runs the LLM Gateway per row.
64+
65+ The returned Column yields a struct ``{status, response, error_code,
66+ error_message}`` for each row. Use ``[...]`` (or ``getField``) to pick the
67+ field you want, e.g. ``llm_gateway_generate_text_col(...)["response"]``.
68+ Per-row failures populate ``status`` / ``error_code`` / ``error_message``
69+ so a single bad row does not abort the whole Spark job.
70+
71+ Example:
72+
73+ >>> result = llm_gateway_generate_text_col(
74+ ... "In one sentence, greet {name} from {city}.",
75+ ... {"name": col("name__c"), "city": col("homecity__c")},
76+ ... model_id="sfdc_ai__DefaultGPT4Omni",
77+ ... )
78+ >>> df.withColumn("greeting__c", result["response"])
79+ >>> # …or keep the struct around and inspect failures:
80+ >>> df.withColumn("llm", result).select(
81+ ... "llm.status", "llm.response", "llm.error_message"
82+ ... )
83+
84+ Args:
85+ template: The prompt template, with ``{field}`` placeholders matching
86+ keys in ``values``. Substitution uses ``str.format``.
87+ values: Either a mapping from placeholder name to Spark ``Column``, or
88+ a single ``Column`` whose value is already a struct.
89+ model_id: LLM model id. Defaults to ``sfdc_ai__DefaultGPT4Omni``.
90+
91+ Returns:
92+ A Spark ``Column`` of ``StructType`` with fields ``status``,
93+ ``response``, ``error_code``, and ``error_message`` (all nullable
94+ strings). On success, ``status == "SUCCESS"`` and ``response`` holds
95+ the generated text; on failure, ``status == "ERROR"`` and the
96+ ``error_*`` fields carry diagnostic detail.
97+ """
98+ gateway = Client ()._get_spark_llm_gateway ()
99+ return gateway .llm_gateway_generate_text_col (template , values , model_id = model_id )
100+
101+
39102class DataCloudObjectType (Enum ):
40103 DLO = "dlo"
41104 DMO = "dmo"
@@ -94,18 +157,21 @@ class Client:
94157 finder: Find a file path
95158 reader: A custom reader to use for reading Data Cloud objects.
96159 writer: A custom writer to use for writing Data Cloud objects.
160+ spark_llm_gateway: Optional custom :class:`SparkLLMGateway`.
97161
98162 Example:
99163 >>> client = Client()
100164 >>> file_path = client.find_file_path("data.csv")
101165 >>> dlo = client.read_dlo("my_dlo")
102166 >>> client.write_to_dmo("my_dmo", dlo)
167+ >>> answer = client.llm_gateway_generate_text("Generate a greeting message")
103168 """
104169
105170 _instance : ClassVar [Optional [Client ]] = None
106171 _reader : BaseDataCloudReader
107172 _writer : BaseDataCloudWriter
108173 _file : DefaultFindFilePath
174+ _spark_llm_gateway : Optional [SparkLLMGateway ]
109175 _data_layer_history : dict [DataCloudObjectType , set [str ]]
110176 _code_type : str
111177
@@ -114,11 +180,13 @@ def __new__(
114180 reader : Optional [BaseDataCloudReader ] = None ,
115181 writer : Optional [BaseDataCloudWriter ] = None ,
116182 spark_provider : Optional [BaseSparkSessionProvider ] = None ,
183+ spark_llm_gateway : Optional [SparkLLMGateway ] = None ,
117184 code_type : str = "script" ,
118185 ) -> Client :
119186
120187 if cls ._instance is None :
121188 cls ._instance = super ().__new__ (cls )
189+ cls ._instance ._spark_llm_gateway = spark_llm_gateway
122190 # Initialize Readers and Writers from config
123191 # and/or provided reader and writer
124192 if reader is None or writer is None :
@@ -225,6 +293,41 @@ def find_file_path(self, file_name: str) -> Path:
225293
226294 return self ._file .find_file_path (file_name ) # type: ignore[no-any-return]
227295
296+ def llm_gateway_generate_text (
297+ self ,
298+ prompt : str ,
299+ model_id : Optional [str ] = None ,
300+ ) -> str :
301+ """Issue a one-shot LLM Gateway call. This is the scalar counterpart to
302+ :func:`llm_gateway_generate_text_col`: it runs **once** — not per row.
303+ Use the column helper method instead when you want to fan a prompt out across
304+ every row of a DataFrame.
305+
306+ Example:
307+
308+ >>> response = Client().llm_gateway_generate_text(
309+ ... "Generate a greeting message"
310+ ... )
311+
312+ Args:
313+ prompt: The literal prompt to send. Plain text — no
314+ ``{field}`` substitution is performed on this string.
315+ model_id: LLM model id to target. Defaults to
316+ ``sfdc_ai__DefaultGPT4Omni`` when ``None``.
317+
318+ Returns:
319+ The generated text as a plain Python ``str``; empty when the
320+ gateway response carries no generated text.
321+ """
322+ return self ._get_spark_llm_gateway ().llm_gateway_generate_text (
323+ prompt , model_id = model_id
324+ )
325+
326+ def _get_spark_llm_gateway (self ) -> SparkLLMGateway :
327+ if self ._spark_llm_gateway is None :
328+ self ._spark_llm_gateway = _build_spark_llm_gateway ()
329+ return self ._spark_llm_gateway
330+
228331 def _validate_data_layer_history_does_not_contain (
229332 self , data_cloud_object_type : DataCloudObjectType
230333 ) -> None :
0 commit comments