Skip to content

Commit 565fd8d

Browse files
authored
refactor: improve initialization code (#4)
Signed-off-by: Will Killian <william.killian@outlook.com>
1 parent 3f0f345 commit 565fd8d

8 files changed

Lines changed: 388 additions & 215 deletions

File tree

example.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,6 @@
408408
"faculty_lab",
409409
"same_room",
410410
"same_lab",
411-
"pack_rooms",
412-
"pack_labs"
411+
"pack_rooms"
413412
]
414413
}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ license-files = ["LICENSE*"]
1111
readme = "README.md"
1212
requires-python = ">=3.12"
1313
dependencies = [
14+
"bidict~=0.23.1",
1415
"click~=8.2.1",
1516
"fastapi~=0.116.1",
1617
"requests~=2.32.5",

src/scheduler/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def main(
5454
if limit is not None:
5555
full_config.limit = limit
5656
limit = full_config.limit
57-
if optimizer_flags is not None:
57+
if optimizer_flags:
5858
full_config.optimizer_flags = optimizer_flags
5959

6060
logger.info(f"Using limit={limit}")

src/scheduler/models/course.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,48 @@
1+
from dataclasses import dataclass
2+
13
import z3
24
from pydantic import BaseModel, ConfigDict, Field, computed_field
35

46
from .time_slot import TimeInstance, TimeSlot
57

68

7-
class Course(BaseModel):
9+
@dataclass
10+
class Course:
811
"""
912
A course with a course_id, section, credits, conflicts, potential labs, potential rooms, and potential faculty.
1013
"""
1114

12-
model_config = ConfigDict(extra="forbid", strict=True, arbitrary_types_allowed=True)
13-
"""
14-
Configuration for the model which forbids extra fields and is strict (@private)
15-
"""
16-
17-
course_id: str = Field(description="The unique identifier for the course")
15+
course_id: str
1816
"""
1917
The unique identifier for the course
2018
"""
2119

22-
credits: int = Field(description="The number of credits for the course")
20+
credits: int
2321
"""
2422
The number of credits for the course
2523
"""
2624

27-
section: int = Field(description="The section number for the course")
25+
section: int
2826
"""
2927
The section number for the course
3028
"""
3129

32-
labs: list[str] = Field(default_factory=list, description="The list of potential labs for the course")
30+
labs: list[str]
3331
"""
3432
The list of potential labs for the course
3533
"""
3634

37-
rooms: list[str] = Field(default_factory=list, description="The list of potential rooms for the course")
35+
rooms: list[str]
3836
"""
3937
The list of potential rooms for the course
4038
"""
4139

42-
conflicts: list[str] = Field(default_factory=list, description="The list of course conflicts for the course")
40+
conflicts: list[str]
4341
"""
4442
The list of course conflicts for the course
4543
"""
4644

47-
faculties: list[str] = Field(default_factory=list, description="The list of potential faculty for the course")
45+
faculties: list[str]
4846
"""
4947
The list of potential faculty for the course
5048
"""
@@ -81,7 +79,7 @@ class CourseInstance(BaseModel):
8179
A course instance with a course, time, faculty, room, and lab.
8280
"""
8381

84-
model_config = ConfigDict(extra="forbid", strict=True)
82+
model_config = ConfigDict(extra="forbid", strict=True, arbitrary_types_allowed=True)
8583
"""
8684
Configuration for the model which forbids extra fields and is strict (@private)
8785
"""

src/scheduler/models/time_slot.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Self
2+
13
from pydantic import BaseModel, ConfigDict, Field, model_serializer
24

35
from .day import Day
@@ -35,16 +37,16 @@ def value(self) -> int:
3537
def __abs__(self) -> "Duration":
3638
return Duration(duration=abs(self.value))
3739

38-
def __lt__(self, other: "Duration") -> bool:
40+
def __lt__(self, other: Self) -> bool:
3941
return self.value < other.value
4042

41-
def __le__(self, other: "Duration") -> bool:
43+
def __le__(self, other: Self) -> bool:
4244
return self.value <= other.value
4345

44-
def __gt__(self, other: "Duration") -> bool:
46+
def __gt__(self, other: Self) -> bool:
4547
return self.value > other.value
4648

47-
def __ge__(self, other: "Duration") -> bool:
49+
def __ge__(self, other: Self) -> bool:
4850
return self.value >= other.value
4951

5052
def __eq__(self, other: object) -> bool:
@@ -117,22 +119,22 @@ def value(self) -> int:
117119
def __add__(self, dur: Duration) -> "TimePoint":
118120
return TimePoint(timepoint=(self.value + dur.value))
119121

120-
def __sub__(self, other: "TimePoint") -> Duration:
122+
def __sub__(self, other: Self) -> Duration:
121123
return Duration(duration=(self.value - other.value))
122124

123125
def __abs__(self) -> Duration:
124126
return Duration(duration=abs(self.value))
125127

126-
def __lt__(self, other: "TimePoint") -> bool:
128+
def __lt__(self, other: Self) -> bool:
127129
return self.value < other.value
128130

129-
def __le__(self, other: "TimePoint") -> bool:
131+
def __le__(self, other: Self) -> bool:
130132
return self.value <= other.value
131133

132-
def __gt__(self, other: "TimePoint") -> bool:
134+
def __gt__(self, other: Self) -> bool:
133135
return self.value > other.value
134136

135-
def __ge__(self, other: "TimePoint") -> bool:
137+
def __ge__(self, other: Self) -> bool:
136138
return self.value >= other.value
137139

138140
def __eq__(self, other: object) -> bool:
@@ -198,11 +200,6 @@ class TimeSlot(BaseModel):
198200
Configuration for the model which allows extra fields and is not strict (@private)
199201
"""
200202

201-
id: int = Field(description="The unique identifier for the time slot")
202-
"""
203-
The unique identifier for the time slot
204-
"""
205-
206203
times: list[TimeInstance] = Field(description="The list of time instances in the time slot")
207204
"""
208205
The list of time instances in the time slot
@@ -219,7 +216,10 @@ class TimeSlot(BaseModel):
219216
"""
220217

221218
def __hash__(self) -> int:
222-
return hash(self.id)
219+
"""
220+
Hash the time slot by its string representation
221+
"""
222+
return hash(str(self))
223223

224224
def lab_time(self) -> TimeInstance | None:
225225
"""

0 commit comments

Comments
 (0)