Commit eb83d1b

Cleaned llm/ type errors

1 parent 71d00f0 commit eb83d1b

9 files changed: +149 -66 lines


nemoguardrails/llm/filters.py

Lines changed: 4 additions & 4 deletions
@@ -150,7 +150,7 @@ def to_messages(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")
 
-    bot_lines = []
+    bot_lines: list[str] = []
     for i, line in enumerate(lines):
         if line.startswith('user "'):
             # If we have bot lines in the buffer, we first add a bot message.
@@ -191,8 +191,8 @@ def to_messages_v2(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")
 
-    user_lines = []
-    bot_lines = []
+    user_lines: list[str] = []
+    bot_lines: list[str] = []
     for line in lines:
         if line.startswith("user action:"):
             if len(bot_lines) > 0:
@@ -285,7 +285,7 @@ def verbose_v1(colang_history: str) -> str:
     return "\n".join(lines)
 
 
-def to_chat_messages(events: List[dict]) -> str:
+def to_chat_messages(events: List[dict]) -> List[dict]:
     """Filter that turns an array of events into a sequence of user/assistant messages.
 
     Properly handles multimodal content by preserving the structure when the content
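
Note on these fixes: a type checker cannot pin down the element type of a bare empty list, so the buffers are now annotated at creation, and `to_chat_messages` builds a list of message dicts, making the old `-> str` return annotation simply wrong. A minimal sketch of both points (message shape illustrative, not taken from the diff):

    from typing import List

    bot_lines: list[str] = []  # element type known at creation, not inferred later
    bot_lines.append('bot "Hello!"')

    # to_chat_messages(...) returns message dicts, not a string:
    messages: List[dict] = [{"role": "assistant", "content": "\n".join(bot_lines)}]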

nemoguardrails/llm/helpers.py

Lines changed: 12 additions & 11 deletions
@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM, BaseLLM
+from langchain_core.language_models.llms import LLM
 
 
-def get_llm_instance_wrapper(
-    llm_instance: Union[LLM, BaseLLM], llm_type: str
-) -> Type[LLM]:
+def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
     """Wraps an LLM instance in a class that can be registered with LLMRails.
 
     This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
         These are needed to allow changes to the arguments of the LLM calls.
         """
         if hasattr(llm_instance, "model_kwargs"):
-            return llm_instance.model_kwargs
+            return getattr(llm_instance, "model_kwargs")
         return {}
 
     @property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
         """
 
         if hasattr(llm_instance, "model_kwargs"):
-            if isinstance(llm_instance.model_kwargs, dict):
-                llm_instance.model_kwargs["temperature"] = self.temperature
-                llm_instance.model_kwargs["streaming"] = self.streaming
+            model_kwargs = getattr(llm_instance, "model_kwargs")
+            if isinstance(model_kwargs, dict):
+                model_kwargs["temperature"] = self.temperature
+                model_kwargs["streaming"] = self.streaming
 
     def _call(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs,
     ) -> str:
         self._modify_instance_kwargs()
-        return llm_instance._call(prompt, stop, run_manager)
+        return llm_instance._call(prompt, stop, run_manager, **kwargs)
 
     async def _acall(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs,
    ) -> str:
         self._modify_instance_kwargs()
-        return await llm_instance._acall(prompt, stop, run_manager)
+        return await llm_instance._acall(prompt, stop, run_manager, **kwargs)
 
     return WrapperLLM
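
A usage sketch for the narrowed signature, assuming a `langchain_core` LLM subclass and the `register_llm_provider` helper exposed by nemoguardrails.llm.providers (`MyLLM` is hypothetical):

    from nemoguardrails.llm.helpers import get_llm_instance_wrapper
    from nemoguardrails.llm.providers import register_llm_provider

    # The instance must now be an LLM (text-completion interface); the
    # Union[LLM, BaseLLM] form was dropped since LLM already subclasses BaseLLM.
    llm_instance = MyLLM(model="some-model")
    WrapperLLM = get_llm_instance_wrapper(llm_instance=llm_instance, llm_type="my_llm")
    register_llm_provider("my_llm", WrapperLLM)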

nemoguardrails/llm/models/initializer.py

Lines changed: 5 additions & 2 deletions
@@ -20,12 +20,15 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM
 
-from .langchain_initializer import ModelInitializationError, init_langchain_model
+from nemoguardrails.llm.models.langchain_initializer import (
+    ModelInitializationError,
+    init_langchain_model,
+)
 
 
 # later we can easily conver it to a class
 def init_llm_model(
-    model_name: Optional[str],
+    model_name: str,
     provider_name: str,
     mode: Literal["chat", "text"],
     kwargs: Dict[str, Any],
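
A call sketch under the tightened signature (values illustrative): callers must now resolve a concrete model name before calling, since `model_name` is no longer Optional.

    from nemoguardrails.llm.models.initializer import init_llm_model

    llm = init_llm_model(
        model_name="gpt-3.5-turbo",  # required str now, not Optional[str]
        provider_name="openai",
        mode="chat",
        kwargs={"temperature": 0.0},
    )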

nemoguardrails/llm/params.py

Lines changed: 8 additions & 7 deletions
@@ -20,7 +20,7 @@
 """
 
 import logging
-from typing import Dict, Type
+from typing import Any, Dict, Type
 
 from langchain.base_language import BaseLanguageModel
 
@@ -33,18 +33,18 @@ class LLMParams:
     def __init__(self, llm: BaseLanguageModel, **kwargs):
         self.llm = llm
         self.altered_params = kwargs
-        self.original_params = {}
+        self.original_params: dict[str, Any] = {}
 
     def __enter__(self):
         # Here we can access and modify the global language model parameters.
-        self.original_params = {}
         for param, value in self.altered_params.items():
             if hasattr(self.llm, param):
                 self.original_params[param] = getattr(self.llm, param)
                 setattr(self.llm, param, value)
 
             elif hasattr(self.llm, "model_kwargs"):
-                if param not in self.llm.model_kwargs:
+                model_kwargs = getattr(self.llm, "model_kwargs", {})
+                if param not in model_kwargs:
                     log.warning(
                         "Parameter %s does not exist for %s. Passing to model_kwargs",
                         param,
@@ -53,9 +53,10 @@ def __enter__(self):
 
                     self.original_params[param] = None
                 else:
-                    self.original_params[param] = self.llm.model_kwargs[param]
+                    self.original_params[param] = model_kwargs[param]
 
-                self.llm.model_kwargs[param] = value
+                model_kwargs[param] = value
+                setattr(self.llm, "model_kwargs", model_kwargs)
 
             else:
                 log.warning(
@@ -64,7 +65,7 @@ def __enter__(self):
                     self.llm.__class__.__name__,
                 )
 
-    def __exit__(self, type, value, traceback):
+    def __exit__(self, exc_type, value, traceback):
         # Restore original parameters when exiting the context
         for param, value in self.original_params.items():
             if hasattr(self.llm, param):
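
A minimal sketch of the context manager this patch hardens, assuming a LangChain-compatible `llm` object: parameters are swapped in on `__enter__`, routed through `model_kwargs` when the attribute does not exist directly (now written back via setattr), and restored on exit. Renaming `type` to `exc_type` also stops shadowing the builtin.

    from nemoguardrails.llm.params import LLMParams

    with LLMParams(llm, temperature=0.1, max_tokens=100):
        ...  # LLM calls in this block observe the altered parameters
    # __exit__ restores the original values here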

nemoguardrails/llm/providers/huggingface/pipeline.py

Lines changed: 30 additions & 9 deletions
@@ -13,14 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from typing import Any, List, Optional
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.schema.output import GenerationChunk
-from langchain_community.llms import HuggingFacePipeline
+
+# Import HuggingFacePipeline with fallbacks for different LangChain versions
+HuggingFacePipeline = None  # type: ignore[assignment]
+
+try:
+    from langchain_community.llms import (
+        HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+    )
+except ImportError:
+    # Fallback for older versions of langchain
+    try:
+        from langchain.llms import (
+            HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+        )
+    except ImportError:
+        # Create a dummy class if HuggingFacePipeline is not available
+        class HuggingFacePipeline:  # type: ignore[misc,no-redef]
+            def __init__(self, *args, **kwargs):
+                raise ImportError("HuggingFacePipeline is not available")
 
 
 class HuggingFacePipelineCompatible(HuggingFacePipeline):
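
The import dance above, generically: try the current location, fall back to the legacy one, and finally bind a stub that fails at instantiation time rather than import time, so type checkers always see a defined name. A sketch of the idiom with illustrative package names:

    try:
        from new_package.module import SomeClass  # current location
    except ImportError:
        try:
            from old_package.module import SomeClass  # legacy location
        except ImportError:
            # Importing this module still succeeds; using the class does not.
            class SomeClass:  # type: ignore[no-redef]
                def __init__(self, *args, **kwargs):
                    raise ImportError("SomeClass is not available")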
@@ -47,12 +66,13 @@ def _call(
         )
 
         # Streaming for NeMo Guardrails is not supported in sync calls.
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
-            raise Exception(
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
+            raise NotImplementedError(
                 "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
             )
 
-        llm_result = self._generate(
+        llm_result = getattr(self, "_generate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
@@ -78,11 +98,12 @@ async def _acall(
         )
 
         # Handle streaming, if the flag is set
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
             # Retrieve the streamer object, needs to be set in model_kwargs
-            streamer = self.model_kwargs.get("streamer")
+            streamer = model_kwargs.get("streamer")
             if not streamer:
-                raise Exception(
+                raise ValueError(
                     "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
                 )
 
@@ -99,7 +120,7 @@ async def _acall(
                 run_manager=run_manager,
                 **kwargs,
             )
-            loop.create_task(self._agenerate(**generation_kwargs))
+            loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))
 
             # And start waiting for the chunks to come in.
             completion = ""
@@ -111,7 +132,7 @@ async def _acall(
 
             return completion
 
-        llm_result = await self._agenerate(
+        llm_result = await getattr(self, "_agenerate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
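
A hedged wiring sketch for the async streaming path (attribute names follow langchain's HuggingFacePipeline; `hf_pipeline` and `streamer` are assumed to exist): a missing "streamer" entry now raises ValueError, and requesting streaming from the sync `_call` raises NotImplementedError.

    async def generate_streaming(hf_pipeline, streamer) -> str:
        llm = HuggingFacePipelineCompatible(
            pipeline=hf_pipeline,  # a transformers text-generation pipeline
            model_kwargs={"streaming": True, "streamer": streamer},
        )
        # Chunks are drained from `streamer` internally until the stop signal.
        return await llm._acall("Hello")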

nemoguardrails/llm/providers/huggingface/streamers.py

Lines changed: 22 additions & 4 deletions
@@ -14,11 +14,27 @@
 # limitations under the License.
 
 import asyncio
+from typing import TYPE_CHECKING, Optional
 
-from transformers.generation.streamers import TextStreamer
+TRANSFORMERS_AVAILABLE = True
+try:
+    from transformers.generation.streamers import (  # type: ignore[import-untyped]
+        TextStreamer,
+    )
+except ImportError:
+    # Fallback if transformers is not available
+    TRANSFORMERS_AVAILABLE = False
 
+    class TextStreamer:  # type: ignore[no-redef]
+        def __init__(self, *args, **kwargs):
+            pass
 
-class AsyncTextIteratorStreamer(TextStreamer):
+
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type: ignore[import-untyped]
+
+
+class AsyncTextIteratorStreamer(TextStreamer):  # type: ignore[misc]
     """
     Simple async implementation for HuggingFace Transformers streamers.
@@ -30,12 +46,14 @@ def __init__(
         self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = asyncio.Queue()
+        self.text_queue: asyncio.Queue[str] = asyncio.Queue()
         self.stop_signal = None
-        self.loop = None
+        self.loop: Optional[asyncio.AbstractEventLoop] = None
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        if self.loop is None:
+            return
         if len(text) > 0:
             asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)
 
