Skip to content

Commit 125df51

Browse files
author
Sam Partee
authored
Run mypy/black on commit (#56)
Run the following on commit - ``mypy`` - ``black`` - ``isort`` Also lots of fixes to satisfy mypy
1 parent e5c7579 commit 125df51

File tree

12 files changed

+222
-105
lines changed

12 files changed

+222
-105
lines changed

.github/workflows/lint.yml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
2+
name: check
3+
4+
on:
5+
pull_request:
6+
push:
7+
branches:
8+
- main
9+
10+
jobs:
11+
check:
12+
name: Style-check ${{ matrix.python-version }}
13+
runs-on: ubuntu-latest
14+
strategy:
15+
matrix:
16+
# Only lint on the min and max supported Python versions.
17+
# It's extremely unlikely that there's a lint issue on any version in between
18+
# that doesn't show up on the min or max versions.
19+
#
20+
# GitHub rate-limits how many jobs can be running at any one time.
21+
# Starting new jobs is also relatively slow,
22+
# so linting on fewer versions makes CI faster.
23+
python-version:
24+
- "3.8"
25+
- "3.11"
26+
27+
steps:
28+
- uses: actions/checkout@v2
29+
- name: Set up Python ${{ matrix.python-version }}
30+
uses: actions/setup-python@v2
31+
with:
32+
python-version: ${{ matrix.python-version }}
33+
- name: Install dependencies
34+
run: |
35+
python -m pip install --upgrade pip
36+
pip install .[dev,all]
37+
38+
- name: check-sort-import
39+
run: |
40+
make check-sort-imports
41+
42+
- name: check-black-format
43+
run: |
44+
make check-format
45+
46+
- name: check-mypy
47+
run: |
48+
make mypy

Makefile

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,32 @@ help:
1919
# help: Style
2020
# help: -------
2121

22-
# help: style - Sort imports and format with black
23-
.PHONY: style
24-
style: sort-imports format
25-
26-
27-
# help: check-style - check code style compliance
28-
.PHONY: check-style
29-
check-style: check-sort-imports check-format
30-
22+
# help: check - run all checks for a commit
23+
.PHONY: check
24+
check: check-format check-sort-imports mypy
3125

3226
# help: format - perform code style format
3327
.PHONY: format
34-
format:
28+
format: sort-imports
3529
@black ./redisvl ./tests/
3630

3731

32+
# help: check-format - check code format compliance
33+
.PHONY: check-format
34+
check-format:
35+
@black --check ./redisvl
36+
37+
3838
# help: sort-imports - apply import sort ordering
3939
.PHONY: sort-imports
4040
sort-imports:
4141
@isort ./redisvl ./tests/ --profile black
4242

43+
# help: check-sort-imports - check imports are sorted
44+
.PHONY: check-sort-imports
45+
check-sort-imports:
46+
@isort ./redisvl --check-only --profile black
47+
4348

4449
# help: check-lint - run static analysis checks
4550
.PHONY: check-lint

redisvl/index.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
self,
3030
name: str,
3131
prefix: str = "rvl",
32-
storage_type: Optional[str] = "hash",
32+
storage_type: str = "hash",
3333
fields: Optional[List["Field"]] = None,
3434
):
3535
self._name = name
@@ -62,7 +62,7 @@ def search(self, *args, **kwargs) -> List["Result"]:
6262
Returns:
6363
List[Result]: A list of search results
6464
"""
65-
results: List["Result"] = self._redis_conn.ft(self._name).search(
65+
results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore
6666
*args, **kwargs
6767
)
6868
return results
@@ -148,7 +148,7 @@ def disconnect(self):
148148
"""Disconnect from the Redis instance"""
149149
self._redis_conn = None
150150

151-
def _get_key(self, record: Dict[str, Any], key_field: str = None) -> str:
151+
def _get_key(self, record: Dict[str, Any], key_field: Optional[str] = None) -> str:
152152
"""Construct the Redis HASH top level key.
153153
154154
Args:
@@ -236,7 +236,7 @@ def __init__(
236236
self,
237237
name: str,
238238
prefix: str = "rvl",
239-
storage_type: Optional[str] = "hash",
239+
storage_type: str = "hash",
240240
fields: Optional[List["Field"]] = None,
241241
):
242242
super().__init__(name, prefix, storage_type, fields)
@@ -313,7 +313,7 @@ def create(self, overwrite: Optional[bool] = False):
313313
# set storage_type, default to hash
314314
storage_type = IndexType.HASH
315315
if self._storage.lower() == "json":
316-
self._storage = IndexType.JSON
316+
storage_type = IndexType.JSON
317317

318318
# Create Index
319319
# will raise correct response error if index already exists
@@ -358,7 +358,7 @@ def load(
358358

359359
# Check if outer interface passes in TTL on load
360360
ttl = kwargs.get("ttl")
361-
with self._redis_conn.pipeline(transaction=False) as pipe:
361+
with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore
362362
for record in data:
363363
key = self._get_key(record, key_field)
364364
pipe.hset(key, mapping=record) # type: ignore
@@ -394,7 +394,7 @@ def __init__(
394394
self,
395395
name: str,
396396
prefix: str = "rvl",
397-
storage_type: Optional[str] = "hash",
397+
storage_type: str = "hash",
398398
fields: Optional[List["Field"]] = None,
399399
):
400400
super().__init__(name, prefix, storage_type, fields)
@@ -467,7 +467,7 @@ async def create(self, overwrite: Optional[bool] = False):
467467
# set storage_type, default to hash
468468
storage_type = IndexType.HASH
469469
if self._storage.lower() == "json":
470-
self._storage = IndexType.JSON
470+
storage_type = IndexType.JSON
471471

472472
# Create Index
473473
await self._redis_conn.ft(self._name).create_index( # type: ignore
@@ -516,7 +516,7 @@ async def _load(record: dict):
516516
key = self._get_key(record, key_field)
517517
await self._redis_conn.hset(key, mapping=record) # type: ignore
518518
if ttl:
519-
await self._redis_conn.expire(key, ttl)
519+
await self._redis_conn.expire(key, ttl) # type: ignore
520520

521521
# gather with concurrency
522522
await asyncio.gather(*[_load(record) for record in data])

redisvl/llmcache/semantic.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List, Optional, Union
22

3-
from redis.commands.search.field import VectorField
3+
from redis.commands.search.field import Field, VectorField
44

55
from redisvl.index import SearchIndex
66
from redisvl.llmcache.base import BaseLLMCache
@@ -14,8 +14,8 @@ class SemanticCache(BaseLLMCache):
1414
"""Cache for Large Language Models."""
1515

1616
# TODO allow for user to change default fields
17-
_vector_field_name = "prompt_vector"
18-
_default_fields = [
17+
_vector_field_name: str = "prompt_vector"
18+
_default_fields: List[Field] = [
1919
VectorField(
2020
_vector_field_name,
2121
"FLAT",
@@ -25,27 +25,27 @@ class SemanticCache(BaseLLMCache):
2525

2626
def __init__(
2727
self,
28-
index_name: Optional[str] = "cache",
29-
prefix: Optional[str] = "llmcache",
30-
threshold: Optional[float] = 0.9,
28+
index_name: str = "cache",
29+
prefix: str = "llmcache",
30+
threshold: float = 0.9,
3131
ttl: Optional[int] = None,
32-
vectorizer: Optional[BaseVectorizer] = HFTextVectorizer(
32+
vectorizer: BaseVectorizer = HFTextVectorizer(
3333
"sentence-transformers/all-mpnet-base-v2"
3434
),
35-
redis_url: Optional[str] = "redis://localhost:6379",
35+
redis_url: str = "redis://localhost:6379",
3636
connection_args: Optional[dict] = None,
3737
index: Optional[SearchIndex] = None,
3838
):
3939
"""Semantic Cache for Large Language Models.
4040
4141
Args:
42-
index_name (Optional[str], optional): The name of the index. Defaults to "cache".
43-
prefix (Optional[str], optional): The prefix for the index. Defaults to "llmcache".
44-
threshold (Optional[float], optional): Semantic threshold for the cache. Defaults to 0.9.
42+
index_name (str, optional): The name of the index. Defaults to "cache".
43+
prefix (str, optional): The prefix for the index. Defaults to "llmcache".
44+
threshold (float, optional): Semantic threshold for the cache. Defaults to 0.9.
4545
ttl (Optional[int], optional): The TTL for the cache. Defaults to None.
46-
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
46+
vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
4747
Defaults to HFTextVectorizer("sentence-transformers/all-mpnet-base-v2").
48-
redis_url (Optional[str], optional): The redis url. Defaults to "redis://localhost:6379".
48+
redis_url (str, optional): The redis url. Defaults to "redis://localhost:6379".
4949
connection_args (Optional[dict], optional): The connection arguments for the redis client. Defaults to None.
5050
index (Optional[SearchIndex], optional): The underlying search index to use for the semantic cache. Defaults to None.
5151

redisvl/query/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from redisvl.query.query import FilterQuery, VectorQuery, RangeQuery
1+
from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery
22

33
__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]

0 commit comments

Comments
 (0)