Skip to content

Commit 247e995

Browse files
committed
hotfix: Lazy import dreadnode to avoid cyclic issues for now
1 parent 82cab98 commit 247e995

10 files changed

Lines changed: 71 additions & 12 deletions

File tree

docs/api/chat.mdx

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1265,7 +1265,7 @@ def __init__(
12651265
"""How to handle cache_control entries on messages."""
12661266
self.task_name: str = generator.to_identifier(short=True)
12671267
"""The name of the pipeline task, used for logging and debugging."""
1268-
self.scorers: list[dn.Scorer[Chat]] = []
1268+
self.scorers: list[Scorer[Chat]] = []
12691269
"""List of dreadnode scorers to evaluate the generated chat upon completion."""
12701270

12711271
self.until_types: list[type[Model]] = []
@@ -2124,6 +2124,8 @@ async def run(
21242124
Returns:
21252125
The generated Chat.
21262126
"""
2127+
import dreadnode as dn
2128+
21272129
if allow_failed:
21282130
warnings.warn(
21292131
"The 'allow_failed' argument is deprecated, use 'on_failed=\"include\"'.",
@@ -2270,6 +2272,8 @@ async def run_batch(
22702272
Returns:
22712273
A list of generatated Chats.
22722274
"""
2275+
import dreadnode as dn
2276+
22732277
on_failed = on_failed or self.on_failed
22742278
count, messages, params = self._fit_batch_args(many, params)
22752279

@@ -2410,6 +2414,8 @@ async def run_many(
24102414
Returns:
24112415
A list of generated Chats.
24122416
"""
2417+
import dreadnode as dn
2418+
24132419
if count < 1:
24142420
raise ValueError("Count must be greater than 0")
24152421

@@ -2536,6 +2542,8 @@ async def run_over(
25362542
Returns:
25372543
A list of generatated Chats.
25382544
"""
2545+
import dreadnode as dn
2546+
25392547
on_failed = on_failed or self.on_failed
25402548

25412549
_generators: list[Generator] = [
@@ -2608,7 +2616,7 @@ Adds one or more scorers to the pipeline to evaluate the generated chat upon com
26082616
```python
26092617
def score(
26102618
self,
2611-
*scorers: dn.Scorer[Chat] | ScorerCallable[Chat],
2619+
*scorers: "Scorer[Chat] | ScorerCallable[Chat]",
26122620
filter: "ChatFilterMode | ChatFilterFunction" = "last",
26132621
) -> "ChatPipeline":
26142622
"""
@@ -2631,6 +2639,8 @@ def score(
26312639
Returns:
26322640
The updated pipeline.
26332641
"""
2642+
import dreadnode as dn
2643+
26342644
self.scorers.extend(
26352645
[
26362646
dn.scorers.wrap_chat(

docs/api/completion.mdx

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -886,6 +886,8 @@ async def run(
886886
Returns:
887887
The generated Completion.
888888
"""
889+
import dreadnode as dn
890+
889891
if on_failed is None:
890892
on_failed = "include" if allow_failed else self.on_failed
891893

@@ -978,6 +980,8 @@ async def run_batch(
978980
Returns:
979981
A list of generatated Completions.
980982
"""
983+
import dreadnode as dn
984+
981985
on_failed = on_failed or self.on_failed
982986
params = self._fit_params(len(many), params)
983987

@@ -1061,6 +1065,8 @@ async def run_many(
10611065
Returns:
10621066
A list of generatated Completions.
10631067
"""
1068+
import dreadnode as dn
1069+
10641070
on_failed = on_failed or self.on_failed
10651071
states = self._initialize_states(count, params)
10661072

@@ -1144,6 +1150,8 @@ async def run_over(
11441150
Returns:
11451151
A list of generatated Completions.
11461152
"""
1153+
import dreadnode as dn
1154+
11471155
on_failed = on_failed or self.on_failed
11481156

11491157
_generators: list[Generator] = [

docs/api/prompt.mdx

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,8 @@ def bind(
264264
await say_hello.bind("gpt-3.5-turbo")("the world")
265265
~~~
266266
"""
267+
import dreadnode as dn
268+
267269
pipeline = self._resolve_to_pipeline(other)
268270
if pipeline.on_failed == "skip":
269271
raise NotImplementedError(
@@ -369,6 +371,8 @@ def bind_many(
369371
await say_hello.bind_many("gpt-4.1")(5, "the world")
370372
~~~
371373
"""
374+
import dreadnode as dn
375+
372376
pipeline = self._resolve_to_pipeline(other)
373377
if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput):
374378
raise NotImplementedError(

docs/api/tools.mdx

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ async def handle_tool_call( # noqa: PLR0912
142142
A tuple containing the message to send back to the generator and a
143143
boolean indicating whether tool calling should stop.
144144
"""
145+
import dreadnode as dn
145146

146147
from rigging.message import ContentText, ContentTypes, Message
147148

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "rigging"
3-
version = "3.2.0"
3+
version = "3.2.1"
44
description = "LLM Interaction Framework"
55
authors = ["Nick Landers <monoxgas@gmail.com>"]
66
license = "MIT"

rigging/chat.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,6 @@
1818
from typing import runtime_checkable
1919
from uuid import UUID, uuid4
2020

21-
import dreadnode as dn
22-
from dreadnode.metric import ScorerCallable
2321
from pydantic import (
2422
BaseModel,
2523
ConfigDict,
@@ -60,6 +58,7 @@
6058
from rigging.util import flatten_list, get_qualified_name
6159

6260
if t.TYPE_CHECKING:
61+
from dreadnode.metric import Scorer, ScorerCallable
6362
from dreadnode.scorers.rigging import ChatFilterFunction, ChatFilterMode
6463
from elasticsearch import AsyncElasticsearch
6564

@@ -761,6 +760,8 @@ def depth(self) -> int:
761760

762761

763762
def _wrap_watch_callback(callback: WatchChatCallback) -> WatchChatCallback:
763+
import dreadnode as dn
764+
764765
callback_name = get_qualified_name(callback)
765766
return dn.task(
766767
name=f"watch - {callback_name}",
@@ -804,7 +805,7 @@ def __init__(
804805
"""How to handle cache_control entries on messages."""
805806
self.task_name: str = generator.to_identifier(short=True)
806807
"""The name of the pipeline task, used for logging and debugging."""
807-
self.scorers: list[dn.Scorer[Chat]] = []
808+
self.scorers: list[Scorer[Chat]] = []
808809
"""List of dreadnode scorers to evaluate the generated chat upon completion."""
809810

810811
self.until_types: list[type[Model]] = []
@@ -1380,7 +1381,7 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
13801381

13811382
def score(
13821383
self,
1383-
*scorers: dn.Scorer[Chat] | ScorerCallable[Chat],
1384+
*scorers: "Scorer[Chat] | ScorerCallable[Chat]",
13841385
filter: "ChatFilterMode | ChatFilterFunction" = "last",
13851386
) -> "ChatPipeline":
13861387
"""
@@ -1403,6 +1404,8 @@ def score(
14031404
Returns:
14041405
The updated pipeline.
14051406
"""
1407+
import dreadnode as dn
1408+
14061409
self.scorers.extend(
14071410
[
14081411
dn.scorers.wrap_chat(
@@ -1512,6 +1515,8 @@ async def _process_tool_call(tool_call: ToolCall) -> bool:
15121515
return next_pipeline.step()
15131516

15141517
async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
1518+
import dreadnode as dn
1519+
15151520
if chat.error: # If we have an error, we should not attempt to parse.
15161521
return None
15171522

@@ -1613,6 +1618,8 @@ async def complete() -> None:
16131618
state.chat = step.chats[-1] if step.chats else state.chat
16141619

16151620
async def _score_chats(self, chats: list[Chat]) -> None:
1621+
import dreadnode as dn
1622+
16161623
if not self.scorers:
16171624
return
16181625

@@ -1644,6 +1651,8 @@ async def _step( # noqa: PLR0915, PLR0912
16441651
params: list[GenerateParams],
16451652
on_failed: FailMode,
16461653
) -> PipelineStepGenerator:
1654+
import dreadnode as dn
1655+
16471656
chats: ChatList = ChatList([])
16481657

16491658
# Some pre-run work
@@ -2015,6 +2024,8 @@ async def run(
20152024
Returns:
20162025
The generated Chat.
20172026
"""
2027+
import dreadnode as dn
2028+
20182029
if allow_failed:
20192030
warnings.warn(
20202031
"The 'allow_failed' argument is deprecated, use 'on_failed=\"include\"'.",
@@ -2123,6 +2134,8 @@ async def run_many(
21232134
Returns:
21242135
A list of generated Chats.
21252136
"""
2137+
import dreadnode as dn
2138+
21262139
if count < 1:
21272140
raise ValueError("Count must be greater than 0")
21282141

@@ -2286,6 +2299,8 @@ async def run_batch(
22862299
Returns:
22872300
A list of generatated Chats.
22882301
"""
2302+
import dreadnode as dn
2303+
22892304
on_failed = on_failed or self.on_failed
22902305
count, messages, params = self._fit_batch_args(many, params)
22912306

@@ -2370,6 +2385,8 @@ async def run_over(
23702385
Returns:
23712386
A list of generatated Chats.
23722387
"""
2388+
import dreadnode as dn
2389+
23732390
on_failed = on_failed or self.on_failed
23742391

23752392
_generators: list[Generator] = [

rigging/completion.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from typing import runtime_checkable
1212
from uuid import UUID, uuid4
1313

14-
import dreadnode as dn
1514
from loguru import logger
1615
from pydantic import BaseModel, ConfigDict, Field, computed_field
1716

@@ -22,6 +21,8 @@
2221
from rigging.util import get_qualified_name
2322

2423
if t.TYPE_CHECKING:
24+
from dreadnode import Span
25+
2526
from rigging.chat import FailMode
2627
from rigging.model import Model, ModelT
2728

@@ -571,6 +572,8 @@ def _until_parse_callback(self, text: str) -> bool:
571572
return False
572573

573574
async def _watch_callback(self, completions: list[Completion]) -> None:
575+
import dreadnode as dn
576+
574577
def wrap_watch_callback(
575578
callback: WatchCompletionCallback,
576579
) -> t.Callable[[list[Completion]], t.Awaitable[None]]:
@@ -623,6 +626,8 @@ async def _post_run(
623626
completions: list[Completion],
624627
on_failed: "FailMode",
625628
) -> list[Completion]:
629+
import dreadnode as dn
630+
626631
if on_failed == "skip":
627632
completions = [c for c in completions if not c.failed]
628633

@@ -723,7 +728,7 @@ def _initialize_states(
723728

724729
async def _run( # noqa: PLR0912
725730
self,
726-
span: dn.Span,
731+
span: "Span",
727732
states: list[RunState],
728733
on_failed: "FailMode",
729734
batch_mode: bool = False, # noqa: FBT001, FBT002
@@ -822,6 +827,8 @@ async def run(
822827
Returns:
823828
The generated Completion.
824829
"""
830+
import dreadnode as dn
831+
825832
if on_failed is None:
826833
on_failed = "include" if allow_failed else self.on_failed
827834

@@ -871,6 +878,8 @@ async def run_many(
871878
Returns:
872879
A list of generatated Completions.
873880
"""
881+
import dreadnode as dn
882+
874883
on_failed = on_failed or self.on_failed
875884
states = self._initialize_states(count, params)
876885

@@ -913,6 +922,8 @@ async def run_batch(
913922
Returns:
914923
A list of generatated Completions.
915924
"""
925+
import dreadnode as dn
926+
916927
on_failed = on_failed or self.on_failed
917928
params = self._fit_params(len(many), params)
918929

@@ -960,6 +971,8 @@ async def run_over(
960971
Returns:
961972
A list of generatated Completions.
962973
"""
974+
import dreadnode as dn
975+
963976
on_failed = on_failed or self.on_failed
964977

965978
_generators: list[Generator] = [

rigging/generator/litellm_.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import re
55
import typing as t
66

7-
import dreadnode as dn
87
import litellm
98
import litellm.types.utils
109
from loguru import logger
@@ -161,6 +160,8 @@ def semaphore(self) -> asyncio.Semaphore:
161160
return self._semaphore
162161

163162
async def supports_function_calling(self) -> bool | None:
163+
import dreadnode as dn
164+
164165
if self._supports_function_calling is not None:
165166
return self._supports_function_calling
166167

rigging/prompt.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import typing as t
1010
from collections import OrderedDict
1111

12-
import dreadnode as dn
1312
from jinja2 import Environment, StrictUndefined, meta
1413
from pydantic import ValidationError
1514
from typing_extensions import Concatenate, ParamSpec # noqa: UP035
@@ -553,6 +552,8 @@ def _resolve_to_pipeline(self, other: ChatPipeline | Generator | Chat | str) ->
553552
raise ValueError(f"Invalid type for binding: {type(other)}")
554553

555554
async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
555+
import dreadnode as dn
556+
556557
if self.output is None or isinstance(self.output, ChatOutput):
557558
return None
558559

@@ -836,6 +837,8 @@ def say_hello(name: str) -> str:
836837
await say_hello.bind("gpt-3.5-turbo")("the world")
837838
```
838839
"""
840+
import dreadnode as dn
841+
839842
pipeline = self._resolve_to_pipeline(other)
840843
if pipeline.on_failed == "skip":
841844
raise NotImplementedError(
@@ -901,6 +904,8 @@ def say_hello(name: str) -> str:
901904
await say_hello.bind_many("gpt-4.1")(5, "the world")
902905
```
903906
"""
907+
import dreadnode as dn
908+
904909
pipeline = self._resolve_to_pipeline(other)
905910
if pipeline.on_failed == "include" and not isinstance(self.output, ChatOutput):
906911
raise NotImplementedError(

rigging/tools/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import warnings
1111
from functools import cached_property
1212

13-
import dreadnode as dn
1413
import typing_extensions as te
1514
from pydantic import (
1615
BaseModel,
@@ -355,6 +354,7 @@ async def handle_tool_call( # noqa: PLR0912
355354
A tuple containing the message to send back to the generator and a
356355
boolean indicating whether tool calling should stop.
357356
"""
357+
import dreadnode as dn
358358

359359
from rigging.message import ContentText, ContentTypes, Message
360360

0 commit comments

Comments (0)