77import os
88import threading
99from datetime import datetime
10- from typing import Dict , Any , List , Optional
10+ from typing import Dict , Any , List , Optional , Callable
11+
12+ import httpx
1113
1214from cozeloop .client import Client
1315from cozeloop ._noop import NOOP_SPAN , _NoopClient
1719from cozeloop .internal .httpclient import Auth
1820from cozeloop .internal .prompt import PromptProvider
1921from cozeloop .internal .trace import TraceProvider
22+ from cozeloop .internal .trace .model .model import FinishEventInfo , TagTruncateConf
23+ from cozeloop .internal .trace .trace import default_finish_event_processor
2024from cozeloop .span import SpanContext , Span
2125
2226logger = logging .getLogger (__name__ )
3539_default_client = None
3640_client_lock = threading .Lock ()
3741
42+ class APIBasePath :
43+ def __init__ (
44+ self ,
45+ trace_span_upload_path : str ,
46+ trace_file_upload_path : str ,
47+ ):
48+ self .trace_span_upload_path = trace_span_upload_path
49+ self .trace_file_upload_path = trace_file_upload_path
50+
3851
3952def _generate_cache_key (* args ) -> str :
4053 key_str = "\t " .join (str (arg ) for arg in args )
@@ -54,8 +67,12 @@ def new_client(
5467 prompt_cache_max_count : int = consts .DEFAULT_PROMPT_CACHE_MAX_COUNT ,
5568 prompt_cache_refresh_interval : int = consts .DEFAULT_PROMPT_CACHE_REFRESH_INTERVAL ,
5669 prompt_trace : bool = False ,
70+ http_client : Optional [httpx .Client ] = None ,
71+ trace_finish_event_processor : Optional [Callable [[FinishEventInfo ], None ]] = None ,
72+ tag_truncate_conf : Optional [TagTruncateConf ] = None ,
73+ api_base_path : Optional [APIBasePath ] = None ,
5774) -> Client :
58- cache_key = _generate_cache_key (
75+ cache_key = _generate_cache_key ( # all args are used to generate cache key
5976 api_base_url ,
6077 workspace_id ,
6178 api_token ,
@@ -67,7 +84,11 @@ def new_client(
6784 ultra_large_report ,
6885 prompt_cache_max_count ,
6986 prompt_cache_refresh_interval ,
70- prompt_trace
87+ prompt_trace ,
88+ http_client ,
89+ trace_finish_event_processor ,
90+ tag_truncate_conf ,
91+ api_base_path ,
7192 )
7293
7394 with _cache_lock :
@@ -88,6 +109,10 @@ def new_client(
88109 prompt_cache_max_count = prompt_cache_max_count ,
89110 prompt_cache_refresh_interval = prompt_cache_refresh_interval ,
90111 prompt_trace = prompt_trace ,
112+ arg_http_client = http_client ,
113+ trace_finish_event_processor = trace_finish_event_processor ,
114+ tag_truncate_conf = tag_truncate_conf ,
115+ api_base_path = api_base_path ,
91116 )
92117 _client_cache [cache_key ] = client
93118 return client
@@ -113,7 +138,11 @@ def __init__(
113138 ultra_large_report : bool = False ,
114139 prompt_cache_max_count : int = consts .DEFAULT_PROMPT_CACHE_MAX_COUNT ,
115140 prompt_cache_refresh_interval : int = consts .DEFAULT_PROMPT_CACHE_REFRESH_INTERVAL ,
116- prompt_trace : bool = False
141+ prompt_trace : bool = False ,
142+ arg_http_client : Optional [httpx .Client ] = None ,
143+ trace_finish_event_processor : Optional [Callable [[FinishEventInfo ], None ]] = None ,
144+ tag_truncate_conf : Optional [TagTruncateConf ] = None ,
145+ api_base_path : Optional [APIBasePath ] = None ,
117146 ):
118147 workspace_id = self ._get_from_env (workspace_id , ENV_WORKSPACE_ID )
119148 api_base_url = self ._get_from_env (api_base_url , ENV_API_BASE_URL )
@@ -136,6 +165,8 @@ def __init__(
136165
137166 self ._workspace_id = workspace_id
138167 inner_client = httpclient .HTTPClient ()
168+ if arg_http_client :
169+ inner_client = arg_http_client
139170 auth = self ._build_auth (
140171 api_base_url = api_base_url ,
141172 http_client = inner_client ,
@@ -151,10 +182,25 @@ def __init__(
151182 timeout = timeout ,
152183 upload_timeout = upload_timeout
153184 )
185+ finish_pro = default_finish_event_processor
186+ if trace_finish_event_processor :
187+ def combined_processor (event_info : FinishEventInfo ):
188+ default_finish_event_processor (event_info )
189+ trace_finish_event_processor (event_info )
190+ finish_pro = combined_processor
191+ span_upload_path = None
192+ file_upload_path = None
193+ if api_base_path :
194+ span_upload_path = api_base_path .trace_span_upload_path
195+ file_upload_path = api_base_path .trace_file_upload_path
154196 self ._trace_provider = TraceProvider (
155197 http_client = http_client ,
156198 workspace_id = workspace_id ,
157- ultra_large_report = ultra_large_report
199+ ultra_large_report = ultra_large_report ,
200+ finish_event_processor = finish_pro ,
201+ tag_truncate_conf = tag_truncate_conf ,
202+ span_upload_path = span_upload_path ,
203+ file_upload_path = file_upload_path ,
158204 )
159205 self ._prompt_provider = PromptProvider (
160206 workspace_id = workspace_id ,
@@ -234,7 +280,7 @@ def start_span(
234280 else :
235281 return self ._trace_provider .start_span (name = name , span_type = span_type , start_time = start_time ,
236282 parent_span_id = child_of .span_id , trace_id = child_of .trace_id ,
237- baggage = child_of .baggage , start_new_trace = start_new_trace )
283+ baggage = child_of .baggage () , start_new_trace = start_new_trace )
238284 except Exception as e :
239285 logger .warning (f"Start span failed, returning noop span. Error: { e } " )
240286 return NOOP_SPAN
0 commit comments