1818from typing import runtime_checkable
1919from uuid import UUID , uuid4
2020
21- import dreadnode as dn
22- from dreadnode .metric import ScorerCallable
2321from pydantic import (
2422 BaseModel ,
2523 ConfigDict ,
6058from rigging .util import flatten_list , get_qualified_name
6159
6260if t .TYPE_CHECKING :
61+ from dreadnode .metric import Scorer , ScorerCallable
6362 from dreadnode .scorers .rigging import ChatFilterFunction , ChatFilterMode
6463 from elasticsearch import AsyncElasticsearch
6564
@@ -761,6 +760,8 @@ def depth(self) -> int:
761760
762761
763762def _wrap_watch_callback (callback : WatchChatCallback ) -> WatchChatCallback :
763+ import dreadnode as dn
764+
764765 callback_name = get_qualified_name (callback )
765766 return dn .task (
766767 name = f"watch - { callback_name } " ,
@@ -804,7 +805,7 @@ def __init__(
804805 """How to handle cache_control entries on messages."""
805806 self .task_name : str = generator .to_identifier (short = True )
806807 """The name of the pipeline task, used for logging and debugging."""
807- self .scorers : list [dn . Scorer [Chat ]] = []
808+ self .scorers : list [Scorer [Chat ]] = []
808809 """List of dreadnode scorers to evaluate the generated chat upon completion."""
809810
810811 self .until_types : list [type [Model ]] = []
@@ -1380,7 +1381,7 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
13801381
13811382 def score (
13821383 self ,
1383- * scorers : dn . Scorer [Chat ] | ScorerCallable [Chat ],
1384+ * scorers : " Scorer[Chat] | ScorerCallable[Chat]" ,
13841385 filter : "ChatFilterMode | ChatFilterFunction" = "last" ,
13851386 ) -> "ChatPipeline" :
13861387 """
@@ -1403,6 +1404,8 @@ def score(
14031404 Returns:
14041405 The updated pipeline.
14051406 """
1407+ import dreadnode as dn
1408+
14061409 self .scorers .extend (
14071410 [
14081411 dn .scorers .wrap_chat (
@@ -1512,6 +1515,8 @@ async def _process_tool_call(tool_call: ToolCall) -> bool:
15121515 return next_pipeline .step ()
15131516
15141517 async def _then_parse (self , chat : Chat ) -> PipelineStepContextManager | None :
1518+ import dreadnode as dn
1519+
15151520 if chat .error : # If we have an error, we should not attempt to parse.
15161521 return None
15171522
@@ -1613,6 +1618,8 @@ async def complete() -> None:
16131618 state .chat = step .chats [- 1 ] if step .chats else state .chat
16141619
16151620 async def _score_chats (self , chats : list [Chat ]) -> None :
1621+ import dreadnode as dn
1622+
16161623 if not self .scorers :
16171624 return
16181625
@@ -1644,6 +1651,8 @@ async def _step( # noqa: PLR0915, PLR0912
16441651 params : list [GenerateParams ],
16451652 on_failed : FailMode ,
16461653 ) -> PipelineStepGenerator :
1654+ import dreadnode as dn
1655+
16471656 chats : ChatList = ChatList ([])
16481657
16491658 # Some pre-run work
@@ -2015,6 +2024,8 @@ async def run(
20152024 Returns:
20162025 The generated Chat.
20172026 """
2027+ import dreadnode as dn
2028+
20182029 if allow_failed :
20192030 warnings .warn (
20202031 "The 'allow_failed' argument is deprecated, use 'on_failed=\" include\" '." ,
@@ -2123,6 +2134,8 @@ async def run_many(
21232134 Returns:
21242135 A list of generated Chats.
21252136 """
2137+ import dreadnode as dn
2138+
21262139 if count < 1 :
21272140 raise ValueError ("Count must be greater than 0" )
21282141
@@ -2286,6 +2299,8 @@ async def run_batch(
22862299 Returns:
22872300 A list of generatated Chats.
22882301 """
2302+ import dreadnode as dn
2303+
22892304 on_failed = on_failed or self .on_failed
22902305 count , messages , params = self ._fit_batch_args (many , params )
22912306
@@ -2370,6 +2385,8 @@ async def run_over(
23702385 Returns:
23712386 A list of generatated Chats.
23722387 """
2388+ import dreadnode as dn
2389+
23732390 on_failed = on_failed or self .on_failed
23742391
23752392 _generators : list [Generator ] = [
0 commit comments