Skip to content

Commit e80debf

Browse files
committed
Revert model initializer order, but only try _init_text_completion_model for text models
1 parent 506007e commit e80debf

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

nemoguardrails/llm/models/langchain_initializer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,12 @@ def init_langchain_model(
142142
initializers: list[ModelInitializer] = [
143143
# Try special case handlers first (handles both chat and text)
144144
ModelInitializer(_handle_model_special_cases, ["chat", "text"]),
145-
# FIXME: is text and chat a good idea?
146-
# For text mode, use text completion, we are using both text and chat as the last resort
147-
ModelInitializer(_init_text_completion_model, ["text", "chat"]),
148145
# For chat mode, first try the standard chat completion API
149146
ModelInitializer(_init_chat_completion_model, ["chat"]),
150147
# For chat mode, fall back to community chat models
151148
ModelInitializer(_init_community_chat_models, ["chat"]),
149+
# For text mode, use text completion
150+
ModelInitializer(_init_text_completion_model, ["text"]),
152151
]
153152

154153
# Track the last exception for better error reporting

tests/llm_providers/test_langchain_initializer.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,25 @@ def test_special_case_called_first(mock_initializers):
7070

7171
def test_chat_completion_called(mock_initializers):
7272
mock_initializers["special"].return_value = None
73-
mock_initializers["text"].return_value = None
7473
mock_initializers["chat"].return_value = "chat_model"
7574
result = init_langchain_model("chat-model", "provider", "chat", {})
7675
assert result == "chat_model"
7776
mock_initializers["special"].assert_called_once()
78-
mock_initializers["text"].assert_called_once()
7977
mock_initializers["chat"].assert_called_once()
8078
mock_initializers["community"].assert_not_called()
79+
mock_initializers["text"].assert_not_called()
8180

8281

8382
def test_community_chat_called(mock_initializers):
8483
mock_initializers["special"].return_value = None
85-
mock_initializers["text"].return_value = None
8684
mock_initializers["chat"].return_value = None
8785
mock_initializers["community"].return_value = "community_model"
8886
result = init_langchain_model("community-chat", "provider", "chat", {})
8987
assert result == "community_model"
9088
mock_initializers["special"].assert_called_once()
91-
mock_initializers["text"].assert_called_once()
9289
mock_initializers["chat"].assert_called_once()
9390
mock_initializers["community"].assert_called_once()
91+
mock_initializers["text"].assert_not_called()
9492

9593

9694
def test_text_completion_called(mock_initializers):
@@ -110,13 +108,12 @@ def test_all_initializers_fail(mock_initializers):
110108
mock_initializers["special"].return_value = None
111109
mock_initializers["chat"].return_value = None
112110
mock_initializers["community"].return_value = None
113-
mock_initializers["text"].return_value = None
114111
with pytest.raises(ModelInitializationError):
115112
init_langchain_model("unknown-model", "provider", "chat", {})
116113
mock_initializers["special"].assert_called_once()
117114
mock_initializers["chat"].assert_called_once()
118115
mock_initializers["community"].assert_called_once()
119-
mock_initializers["text"].assert_called_once()
116+
mock_initializers["text"].assert_not_called()
120117

121118

122119
def test_unsupported_mode(mock_initializers):
@@ -151,44 +148,41 @@ def test_all_initializers_raise_exceptions(mock_initializers):
151148
mock_initializers["special"].assert_called_once()
152149
mock_initializers["chat"].assert_called_once()
153150
mock_initializers["community"].assert_called_once()
154-
mock_initializers["text"].assert_called_once()
151+
mock_initializers["text"].assert_not_called()
155152

156153

157154
def test_duplicate_modes_in_initializer(mock_initializers):
158155
mock_initializers["special"].return_value = None
159-
mock_initializers["text"].return_value = None
160156
mock_initializers["chat"].return_value = "chat_model"
161157
result = init_langchain_model("chat-model", "provider", "chat", {})
162158
assert result == "chat_model"
163159
mock_initializers["special"].assert_called_once()
164-
mock_initializers["text"].assert_called_once()
165160
mock_initializers["chat"].assert_called_once()
166161
mock_initializers["community"].assert_not_called()
162+
mock_initializers["text"].assert_not_called()
167163

168164

169165
def test_chat_completion_called_when_special_returns_none(mock_initializers):
170166
mock_initializers["special"].return_value = None
171-
mock_initializers["text"].return_value = None
172167
mock_initializers["chat"].return_value = "chat_model"
173168
result = init_langchain_model("chat-model", "provider", "chat", {})
174169
assert result == "chat_model"
175170
mock_initializers["special"].assert_called_once()
176-
mock_initializers["text"].assert_called_once()
177171
mock_initializers["chat"].assert_called_once()
178172
mock_initializers["community"].assert_not_called()
173+
mock_initializers["text"].assert_not_called()
179174

180175

181176
def test_community_chat_called_when_previous_fail(mock_initializers):
182177
mock_initializers["special"].return_value = None
183-
mock_initializers["text"].return_value = None
184178
mock_initializers["chat"].return_value = None
185179
mock_initializers["community"].return_value = "community_model"
186180
result = init_langchain_model("community-chat", "provider", "chat", {})
187181
assert result == "community_model"
188182
mock_initializers["special"].assert_called_once()
189-
mock_initializers["text"].assert_called_once()
190183
mock_initializers["chat"].assert_called_once()
191184
mock_initializers["community"].assert_called_once()
185+
mock_initializers["text"].assert_not_called()
192186

193187

194188
def test_text_completion_called_when_previous_fail(mock_initializers):
@@ -204,10 +198,11 @@ def test_text_completion_called_when_previous_fail(mock_initializers):
204198
mock_initializers["text"].assert_called_once()
205199

206200

207-
def test_text_completion_supports_chat_mode(mock_initializers):
201+
def test_text_mode_only_calls_text_initializers(mock_initializers):
202+
"""Test that text mode only tries initializers that support text mode."""
208203
mock_initializers["special"].return_value = None
209204
mock_initializers["text"].return_value = "text_model"
210-
result = init_langchain_model("text-model", "provider", "chat", {})
205+
result = init_langchain_model("text-model", "provider", "text", {})
211206
assert result == "text_model"
212207
mock_initializers["special"].assert_called_once()
213208
mock_initializers["text"].assert_called_once()

0 commit comments

Comments
 (0)