Skip to content

Commit ec65164

Browse files
authored
feat: Update Tinker renderers (#613)
1 parent 806ca83 commit ec65164

File tree

16 files changed

+2254
-155
lines changed

16 files changed

+2254
-155
lines changed

src/art/tinker/cookbook_v/image_processing_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ def get_image_processor(model_name: str) -> ImageProcessor:
2828

2929
from transformers.models.auto.image_processing_auto import AutoImageProcessor
3030

31-
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
32-
assert processor.is_fast, f"Could not load fast image processor for {model_name}"
31+
kwargs: dict[str, Any] = {}
32+
if model_name == "moonshotai/Kimi-K2.5":
33+
kwargs["trust_remote_code"] = True
34+
kwargs["revision"] = "3367c8d1c68584429fab7faf845a32d5195b6ac1"
35+
36+
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True, **kwargs)
3337
return processor
3438

3539

src/art/tinker/cookbook_v/renderers/__init__.py

Lines changed: 130 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
python -m tinker_cookbook.supervised.viz_sft_dataset dataset_path=Tulu3Builder renderer_name=role_colon
66
"""
77

8+
from collections.abc import Callable
9+
from typing import Any
10+
811
from ..image_processing_utils import ImageProcessor
912
from ..tokenizer_utils import Tokenizer
1013

@@ -14,15 +17,21 @@
1417
ContentPart,
1518
ImagePart,
1619
Message,
20+
# Streaming types
21+
MessageDelta,
1722
# Renderer base
1823
RenderContext,
1924
Renderer,
2025
Role,
26+
StreamingMessageHeader,
27+
StreamingTextDelta,
28+
StreamingThinkingDelta,
2129
TextPart,
2230
ThinkingPart,
2331
ToolCall,
2432
ToolSpec,
2533
TrainOnWhat,
34+
Utf8TokenDecoder,
2635
# Utility functions
2736
ensure_text,
2837
format_content_as_string,
@@ -35,9 +44,59 @@
3544
from .gpt_oss import GptOssRenderer
3645
from .qwen3 import Qwen3Renderer
3746

47+
# Global registry for custom renderer factories
48+
_CUSTOM_RENDERER_REGISTRY: dict[str, Callable[[Tokenizer, Any], Renderer]] = {}
49+
50+
51+
def register_renderer(
52+
name: str,
53+
factory: Callable[[Tokenizer, Any], Renderer],
54+
) -> None:
55+
"""Register a custom renderer factory.
56+
57+
Args:
58+
name: The renderer name
59+
factory: A callable that takes (tokenizer, image_processor) and returns a Renderer.
60+
61+
Example:
62+
def my_renderer_factory(tokenizer, image_processor=None):
63+
return MyCustomRenderer(tokenizer)
64+
65+
register_renderer("Foo/foo_renderer", my_renderer_factory)
66+
"""
67+
_CUSTOM_RENDERER_REGISTRY[name] = factory
68+
69+
70+
def get_registered_renderer_names() -> list[str]:
71+
"""Return a list of all registered custom renderer names."""
72+
return list(_CUSTOM_RENDERER_REGISTRY.keys())
73+
74+
75+
def is_renderer_registered(name: str) -> bool:
76+
"""Check if a renderer name is registered."""
77+
return name in _CUSTOM_RENDERER_REGISTRY
78+
79+
80+
def unregister_renderer(name: str) -> bool:
81+
"""Unregister a custom renderer factory.
82+
83+
Args:
84+
name: The renderer name to unregister.
85+
86+
Returns:
87+
True if the renderer was unregistered, False if it wasn't registered.
88+
"""
89+
if name in _CUSTOM_RENDERER_REGISTRY:
90+
del _CUSTOM_RENDERER_REGISTRY[name]
91+
return True
92+
return False
93+
3894

3995
def get_renderer(
40-
name: str, tokenizer: Tokenizer, image_processor: ImageProcessor | None = None
96+
name: str,
97+
tokenizer: Tokenizer,
98+
image_processor: ImageProcessor | None = None,
99+
model_name: str | None = None,
41100
) -> Renderer:
42101
"""Factory function to create renderers by name.
43102
@@ -50,16 +109,24 @@ def get_renderer(
50109
- "qwen3_vl_instruct": Qwen3 vision-language instruct (no thinking)
51110
- "qwen3_disable_thinking": Qwen3 with thinking disabled
52111
- "qwen3_instruct": Qwen3 instruct 2507 (no thinking)
112+
- "qwen3_5": Qwen3.5 VL with thinking
113+
- "qwen3_5_disable_thinking": Qwen3.5 VL with thinking disabled
53114
- "deepseekv3": DeepSeek V3 (defaults to non-thinking mode)
54115
- "deepseekv3_disable_thinking": DeepSeek V3 non-thinking (alias)
55116
- "deepseekv3_thinking": DeepSeek V3 thinking mode
56117
- "kimi_k2": Kimi K2 Thinking format
118+
- "kimi_k25": Kimi K2.5 with thinking enabled
119+
- "kimi_k25_disable_thinking": Kimi K2.5 with thinking disabled
57120
- "gpt_oss_no_sysprompt": GPT-OSS without system prompt
58121
- "gpt_oss_low_reasoning": GPT-OSS with low reasoning
59122
- "gpt_oss_medium_reasoning": GPT-OSS with medium reasoning
60123
- "gpt_oss_high_reasoning": GPT-OSS with high reasoning
124+
- Custom renderers registered via register_renderer()
61125
tokenizer: The tokenizer to use.
62126
image_processor: Required for VL renderers.
127+
model_name: Model name for pickle metadata. If None, falls back to
128+
``tokenizer.name_or_path``. Provide this explicitly when the tokenizer
129+
was loaded with a remapped name (e.g., Llama 3 models).
63130
64131
Returns:
65132
A Renderer instance.
@@ -68,63 +135,98 @@ def get_renderer(
68135
ValueError: If the renderer name is unknown.
69136
AssertionError: If a VL renderer is requested without an image_processor.
70137
"""
138+
139+
def _stamp_pickle_metadata(renderer: Renderer) -> Renderer:
140+
"""Stamp renderer with metadata needed for pickle support."""
141+
renderer._renderer_name = name
142+
renderer._model_name = (
143+
model_name if model_name is not None else tokenizer.name_or_path
144+
)
145+
renderer._has_image_processor = image_processor is not None
146+
return renderer
147+
148+
# Check custom registry first
149+
if (factory := _CUSTOM_RENDERER_REGISTRY.get(name)) is not None:
150+
return _stamp_pickle_metadata(factory(tokenizer, image_processor))
151+
71152
# Import renderer classes lazily to avoid circular imports and keep exports minimal
72153
from .deepseek_v3 import DeepSeekV3DisableThinkingRenderer
73154
from .gpt_oss import GptOssRenderer
74155
from .kimi_k2 import KimiK2Renderer
156+
from .kimi_k25 import KimiK25DisableThinkingRenderer, KimiK25Renderer
75157
from .llama3 import Llama3Renderer
76158
from .qwen3 import (
77159
Qwen3DisableThinkingRenderer,
78160
Qwen3InstructRenderer,
79161
Qwen3VLInstructRenderer,
80162
Qwen3VLRenderer,
81163
)
164+
from .qwen3_5 import Qwen3_5DisableThinkingRenderer, Qwen3_5Renderer
82165
from .role_colon import RoleColonRenderer
83166

167+
renderer: Renderer
84168
if name == "role_colon":
85-
return RoleColonRenderer(tokenizer)
169+
renderer = RoleColonRenderer(tokenizer)
86170
elif name == "llama3":
87-
return Llama3Renderer(tokenizer)
171+
renderer = Llama3Renderer(tokenizer)
88172
elif name == "qwen3":
89-
return Qwen3Renderer(tokenizer)
173+
renderer = Qwen3Renderer(tokenizer)
90174
elif name == "qwen3_vl":
91175
assert image_processor is not None, (
92176
"qwen3_vl renderer requires an image_processor"
93177
)
94-
return Qwen3VLRenderer(tokenizer, image_processor)
178+
renderer = Qwen3VLRenderer(tokenizer, image_processor)
95179
elif name == "qwen3_vl_instruct":
96180
assert image_processor is not None, (
97181
"qwen3_vl_instruct renderer requires an image_processor"
98182
)
99-
return Qwen3VLInstructRenderer(tokenizer, image_processor)
183+
renderer = Qwen3VLInstructRenderer(tokenizer, image_processor)
100184
elif name == "qwen3_disable_thinking":
101-
return Qwen3DisableThinkingRenderer(tokenizer)
185+
renderer = Qwen3DisableThinkingRenderer(tokenizer)
102186
elif name == "qwen3_instruct":
103-
return Qwen3InstructRenderer(tokenizer)
187+
renderer = Qwen3InstructRenderer(tokenizer)
188+
elif name == "qwen3_5":
189+
renderer = Qwen3_5Renderer(tokenizer, image_processor=image_processor)
190+
elif name == "qwen3_5_disable_thinking":
191+
renderer = Qwen3_5DisableThinkingRenderer(
192+
tokenizer, image_processor=image_processor
193+
)
104194
elif name == "deepseekv3":
105195
# Default to non-thinking mode (matches HF template default behavior)
106-
return DeepSeekV3DisableThinkingRenderer(tokenizer)
196+
renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)
107197
elif name == "deepseekv3_disable_thinking":
108198
# Alias for backward compatibility
109-
return DeepSeekV3DisableThinkingRenderer(tokenizer)
199+
renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)
110200
elif name == "deepseekv3_thinking":
111-
return DeepSeekV3ThinkingRenderer(tokenizer)
201+
renderer = DeepSeekV3ThinkingRenderer(tokenizer)
112202
elif name == "kimi_k2":
113-
return KimiK2Renderer(tokenizer)
203+
renderer = KimiK2Renderer(tokenizer)
204+
elif name == "kimi_k25":
205+
renderer = KimiK25Renderer(tokenizer, image_processor=image_processor)
206+
elif name == "kimi_k25_disable_thinking":
207+
renderer = KimiK25DisableThinkingRenderer(
208+
tokenizer, image_processor=image_processor
209+
)
114210
elif name == "gpt_oss_no_sysprompt":
115-
return GptOssRenderer(tokenizer, use_system_prompt=False)
211+
renderer = GptOssRenderer(tokenizer, use_system_prompt=False)
116212
elif name == "gpt_oss_low_reasoning":
117-
return GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort="low")
213+
renderer = GptOssRenderer(
214+
tokenizer, use_system_prompt=True, reasoning_effort="low"
215+
)
118216
elif name == "gpt_oss_medium_reasoning":
119-
return GptOssRenderer(
217+
renderer = GptOssRenderer(
120218
tokenizer, use_system_prompt=True, reasoning_effort="medium"
121219
)
122220
elif name == "gpt_oss_high_reasoning":
123-
return GptOssRenderer(
221+
renderer = GptOssRenderer(
124222
tokenizer, use_system_prompt=True, reasoning_effort="high"
125223
)
126224
else:
127-
raise ValueError(f"Unknown renderer: {name}")
225+
raise ValueError(
226+
f"Unknown renderer: {name}. If this is a custom renderer, please register it via register_renderer()."
227+
)
228+
229+
return _stamp_pickle_metadata(renderer)
128230

129231

130232
__all__ = [
@@ -137,6 +239,12 @@ def get_renderer(
137239
"ThinkingPart",
138240
"ToolCall",
139241
"ToolSpec",
242+
# Streaming types
243+
"MessageDelta",
244+
"StreamingMessageHeader",
245+
"StreamingTextDelta",
246+
"StreamingThinkingDelta",
247+
"Utf8TokenDecoder",
140248
# Renderer base
141249
"RenderContext",
142250
"Renderer",
@@ -146,6 +254,11 @@ def get_renderer(
146254
"format_content_as_string",
147255
"get_text_content",
148256
"parse_content_blocks",
257+
# Registry
258+
"register_renderer",
259+
"unregister_renderer",
260+
"get_registered_renderer_names",
261+
"is_renderer_registered",
149262
# Factory
150263
"get_renderer",
151264
# Renderer classes (used by tests)

0 commit comments

Comments
 (0)