diff --git a/.gitignore b/.gitignore index ff3cfc4..b57285a 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,7 @@ cython_debug/ *.log .env +.env.* data/ + +*.sqlite3 diff --git a/README.md b/README.md index 7ad106b..b33b7f9 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ hatch env create hatch run dev ``` +Database migrations are handled with [Alembic](https://alembic.sqlalchemy.org). + +``` sh +alembic revision --autogenerate -m "Brief description of changes..." +alembic upgrade head +``` + Run [black](https://pypi.org/project/black/) and [isort](https://pycqa.github.io/isort/index.html) on your code: ``` sh diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..1e8c258 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..5cdf272 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,82 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from pairbot import config as pairbot_config +from pairbot import models + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +config.set_main_option("sqlalchemy.url", pairbot_config.DATABASE_URL) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = models.Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/fab85b27c1d3_nuked.py b/alembic/versions/fab85b27c1d3_nuked.py new file mode 100644 index 0000000..5678b18 --- /dev/null +++ b/alembic/versions/fab85b27c1d3_nuked.py @@ -0,0 +1,64 @@ +"""nuked + +Revision ID: fab85b27c1d3 +Revises: +Create Date: 2023-11-25 23:32:30.908140 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'fab85b27c1d3' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('channel', + sa.Column('guild_id', sa.Integer(), nullable=False), + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('leetcode_integration', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('guild_id', 'channel_id') + ) + op.create_table('pairing', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('thread_id', sa.Integer(), nullable=False), + sa.Column('user_1_id', sa.Integer(), nullable=False), + sa.Column('user_2_id', sa.Integer(), nullable=False), + sa.Column('user_3_id', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('channel_id', 'thread_id') + ) + op.create_table('schedule', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('available_0', sa.Boolean(), nullable=False), + sa.Column('available_1', sa.Boolean(), nullable=False), + sa.Column('available_2', sa.Boolean(), nullable=False), + sa.Column('available_3', sa.Boolean(), nullable=False), + sa.Column('available_4', sa.Boolean(), nullable=False), + sa.Column('available_5', sa.Boolean(), nullable=False), + sa.Column('available_6', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('channel_id', 'user_id') + ) + op.create_table('skip', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('date', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('channel_id', 'user_id', 'date') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('skip') + op.drop_table('schedule') + op.drop_table('pairing') + op.drop_table('channel') + # ### end Alembic commands ### diff --git a/pairbot/__init__.py b/pairbot/__init__.py index eb123b4..86966d1 100644 --- a/pairbot/__init__.py +++ b/pairbot/__init__.py @@ -1 +1 @@ -from .client import run +from .main import run diff --git a/pairbot/app.py b/pairbot/app.py new file mode 100644 index 0000000..fe32b47 --- /dev/null +++ b/pairbot/app.py @@ -0,0 +1,99 @@ +import pdb +import sys + +from typing import ( + Any, + Callable, + Concatenate, + Coroutine, + ParamSpec, + TypeVar, +) + +import logging +from functools import wraps +from contextvars import ContextVar + +import discord + +logger = logging.getLogger(__name__) + +# Globally accessible context variables +interaction_ctx: ContextVar[discord.Interaction[Any]] = ContextVar("interaction") + + +class InteractionContext: + def __init__(self, interaction: discord.Interaction[Any]): + self.interaction = interaction + self.token = None + + def __enter__(self) -> None: + self.token = interaction_ctx.set(self.interaction) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self.token is not None: + interaction_ctx.reset(self.token) + + +class InteractionLoggingAdapter(logging.LoggerAdapter): + """A custom formatter for logging inside an interaction context""" + def process(self, msg, kwargs): + interaction = interaction_ctx.get(None) + new_msg = "" + if interaction is not None: + new_msg += "Interaction" + if interaction.guild is not None: + new_msg += f" in guild \"{interaction.guild.name}\" ({interaction.guild.id})" + if interaction.channel is not None and isinstance(interaction.channel, discord.TextChannel): + new_msg += f" channel \"#{interaction.channel.name} ({interaction.channel.id})\"" + new_msg += f" by user \"@{interaction.user.name}\" ({interaction.user.id}): " + new_msg += str(msg) + return new_msg, kwargs + + +# Type fuckery for the command decorator +T = TypeVar("T") +P = ParamSpec("P") +CommandCallback = Callable[Concatenate[discord.Interaction[Any], P], Coroutine[Any, Any, T]] + +class DiscordApp: + """A wrapper around a discord client instance.""" + def __init__(self, intents: discord.Intents, **options: Any): + self.client = discord.Client(intents=intents, **options) + self.command_tree = discord.app_commands.CommandTree(self.client) + self.logger = InteractionLoggingAdapter(logger) + + # register on_ready event + @self.client.event + async def on_ready(): + for guild in self.client.guilds: + self.logger.info("Copying command tree to guild \"%s\" (%d)", guild.name, guild.id) + self.command_tree.copy_global_to(guild=guild) + await self.command_tree.sync(guild=guild) + self.logger.info("Application ready.") + + def command(self, **options: Any): + """Wrapper around discord slash commands.""" + def decorator(callback: CommandCallback): + @self.command_tree.command(**options) + @wraps(callback) + async def wrapper(i: discord.Interaction[Any], *args: P.args, **kwargs: P.kwargs) -> None: + assert i.command is not None + pretty_kwargs = ( + " with arguments { " + + ", ".join((f"{key}=\"{str(value)}\"" for key, value in kwargs.items())) + + " }" + if len(kwargs) > 0 else " with no arguments" + ) + try: + with InteractionContext(i): + self.logger.info("Executing slash command /%s%s", i.command.name, pretty_kwargs) + await callback(i, *args, **kwargs) + except Exception as e: + self.logger.error(e, exc_info=True) + await i.response.send_message("Pairbot broke somehow :v", ephemeral=True) + return wrapper + return decorator + + def run(self, token: str, *args, **kwargs): + self.client.run(token, *args, **kwargs) diff --git a/pairbot/client.py b/pairbot/client.py deleted file mode 100644 index 7b77418..0000000 --- a/pairbot/client.py +++ /dev/null @@ -1,384 +0,0 @@ -import json -import logging -import os -import random -import sqlite3 -from datetime import datetime -from pathlib import Path -from typing import List - -import discord -from discord import app_commands -from discord.ext import tasks -from dotenv import load_dotenv - -from .db import PairingsDB, ScheduleDB, Timeblock -from .utils import get_user_name, parse_args, read_guild_to_channel - -load_dotenv() -args = parse_args() -if args.dev: - print("Running in dev mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN_DEV") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds-dev.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule-dev.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings-dev.db" - LOG_FILE = "pairbot-dev.log" -else: - print("Running in prod mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings.db" - LOG_FILE = "pairbot.log" -SORRY = "Unexpected error." -Path(DATA_DIR).mkdir(parents=True, exist_ok=True) - -logging.basicConfig( - filename=LOG_FILE, - filemode="a", - format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - level=logging.DEBUG, -) -logger = logging.getLogger("pairbot") - -intents = discord.Intents.default() -intents.members = True -intents.message_content = True -client = discord.Client(intents=intents) -tree = app_commands.CommandTree(client) -db = ScheduleDB(SCHEDULE_DB_PATH) -pairings_db = PairingsDB(PAIRINGS_DB_PATH) - - -# Discord API currently doesn't support variadic arguments -# https://github.com/discord/discord-api-docs/discussions/3286 -@tree.command( - name="subscribe", - description="Add timeblocks to find a partner for pair programming. \ -Matches go out at 8am UTC that day.", -) -@app_commands.describe( - timeblock="Choose WEEK to get a partner for the whole week (pairs announced Monday UTC)." -) -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] -) -async def _subscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.insert(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} subscribed T:{timeblock.name}." - ) - msg = ( - f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`. " - f"You can call `/subscribe` again to sign up for more days." - ) - await interaction.response.send_message(msg, ephemeral=True) - except sqlite3.IntegrityError as e: - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} failed subscribe T:{timeblock.name}." - ) - logger.warning(e, exc_info=True) - msg = ( - f"You are already subscribed to {timeblock}. " - f"Call `/unsubscribe` to remove a subscription or `/schedule` to view your schedule." - ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="unsubscribe", description="Remove timeblocks for pair programming.") -@app_commands.describe(timeblock="Call `/unsubscribe-all` to remove all timeblocks.") -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] -) -async def _unsubscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.delete(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} unsubscribed T:{timeblock.name}." - ) - msg = f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="unsubscribe-all", description="Remove all timeblocks for pair programming." -) -async def _unsubscribe_all(interaction: discord.Interaction): - try: - db.unsubscribe(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} called unsubscribe-all." - ) - msg = "Your pairing subscriptions have been removed. To rejoin, call `/subscribe` again." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="schedule", description="View your pairing schedule.") -async def _schedule(interaction: discord.Interaction): - try: - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - schedule = Timeblock.generate_schedule(timeblocks) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} queried schedule {schedule}." - ) - msg = ( - f"Your current schedule is `{schedule}`. " - "You can call `/subscribe` or `/unsubscribe` to modify it." - ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="set-channel", description="Set a channel for bot messages (admin only)." -) -@app_commands.checks.has_permissions(administrator=True) -async def _set_channel(interaction: discord.Interaction, channel: discord.TextChannel): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - guild_to_channel[str(interaction.guild_id)] = channel.id - with open(GUILDS_PATH, "w") as f: - json.dump(guild_to_channel, f) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} set-channel C:{channel.id}." - ) - msg = f"Successfully set bot channel to `{channel.name}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="pairwith", - description="Start an immediate pairing session with another member.", -) -async def _pairwith(interaction: discord.Interaction, user: discord.Member): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel_id = guild_to_channel[str(interaction.guild_id)] - channel = client.get_channel(channel_id) - users = [interaction.user, user] - notify_msg = ( - f"<@{interaction.user.id}> has started an on-demand pair with you, <@{user.id}>. " - "Happy pairing! :computer:" - ) - await create_group_thread(interaction.guild_id, users, channel, notify_msg) - logger.info( - f"G:{interaction.guild_id} C:{channel.id} on-demand paired U:{interaction.user.id} with {user.id}." - ) - await interaction.response.send_message( - f"Thread with {get_user_name(user)} created in channel `{channel.name}`.", - ephemeral=True, - ) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -async def dm_user(user: discord.User, msg: str): - try: - channel = await user.create_dm() - await channel.send(msg) - except Exception as e: - logger.error(e, exc_info=True) - - -async def create_group_thread( - guild_id: int, - users: List[discord.User], - channel: discord.TextChannel, - notify_msg: str, -): - # @ notifying users in a private thread invites them - # so `notify_msg` must notify for this to work - userids = [user.id for user in users] - thread_id = pairings_db.query_userids(guild_id, userids, channel.id) - thread = None - if thread_id is not None: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - try: - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - except discord.errors.NotFound: - logger.debug(f"Couldn't fetch thread {thread_id}, maybe deleted?") - pairings_db.delete(guild_id, userids, channel.id, thread_id) - if thread is None: - title = ", ".join(get_user_name(user) for user in users) - thread = await channel.create_thread( - name=f"{title}", auto_archive_duration=10080 - ) - logger.debug(f"Created new thread {thread.id} for G:{guild_id} U:{userids}") - pairings_db.insert(guild_id, userids, channel.id, thread.id) - else: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - await thread.send(notify_msg) - - -async def on_tree_error( - interaction: discord.Interaction, error: app_commands.AppCommandError -): - if isinstance(error, app_commands.CommandOnCooldown): - return await interaction.response.send_message( - f"Command is currently on cooldown, try again in {error.retry_after:.2f} seconds.", - ephemeral=True, - ) - elif isinstance(error, app_commands.MissingPermissions): - return await interaction.response.send_message( - "You don't have the permissions to do that.", ephemeral=True - ) - else: - raise error - - -@tasks.loop(hours=1) -async def pairing_cron(): - def should_run(): - now = datetime.utcnow() - hour = now.time().hour - logger.debug(f"Checking pairing job at UTC:{now}.") - return hour == 8 - - if should_run(): - await run_pairing() - - -async def run_pairing(): - now = datetime.utcnow() - print(now) - logger.debug(f"Running pairing job at UTC:{now}.") - weekday = now.weekday() - weekday_map = { - 0: Timeblock.Monday, - 1: Timeblock.Tuesday, - 2: Timeblock.Wednesday, - 3: Timeblock.Thursday, - 4: Timeblock.Friday, - 5: Timeblock.Saturday, - 6: Timeblock.Sunday, - } - timeblock = weekday_map[weekday] - for guild in client.guilds: - await pair(guild.id, timeblock) - # weekly Monday match - if weekday == 0: - await pair(guild.id, Timeblock.WEEK) - - -async def pair(guild_id: int, timeblock: Timeblock): - try: - userids = db.query_timeblock(guild_id, timeblock) - users = [client.get_user(userid) for userid in userids] - # Users may leave the server without unsubscribing - # TODO: listen to that event and drop them from the table - users = list(filter(None, users)) - logger.info( - f"Pairing for G:{guild_id} T:{timeblock.name} with {len(users)}/{len(userids)} users." - ) - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel = client.get_channel(guild_to_channel[str(guild_id)]) - if len(users) < 2: - for user in users: - logger.info( - f"G:{guild_id} T:{timeblock.name} pair failed, dming U:{user.id}." - ) - msg = ( - f"Thanks for signing up for pairing this {timeblock}. " - "Unfortunately, there was nobody else available this time." - ) - await dm_user(user, msg) - await channel.send( - f"Not enough signups this {timeblock}. Try `/subscribe` to sign up!" - ) - return - - random.shuffle(users) - groups = [users[i :: len(users) // 2] for i in range(len(users) // 2)] - for group in groups: - notify_msg = ", ".join(f"<@{user.id}>" for user in group) - notify_msg = f"{notify_msg}: you've been matched together for this {timeblock}. Happy pairing! :computer:" - await create_group_thread(guild_id, group, channel, notify_msg) - logger.info( - f"G:{guild_id} C:{channel.id} paired U:{[user.id for user in group]}." - ) - await channel.send( - f"Pairings for {len(users)} users have been sent out for this {timeblock}. Try `/subscribe` to sign up!" - ) - except Exception as e: - logger.error(e, exc_info=True) - - -def local_setup(): - try: - read_guild_to_channel(GUILDS_PATH) - except Exception: - with open(GUILDS_PATH, "w") as f: - json.dump({}, f) - - -@client.event -async def on_ready(): - local_setup() - await client.wait_until_ready() - tree.on_error = on_tree_error - for guild in client.guilds: - tree.copy_global_to(guild=guild) - await tree.sync(guild=guild) - print("Code sync complete!") - pairing_cron.start() - print("Starting cron loop...") - logger.info("Bot started.") - - -def run(): - client.run(BOT_TOKEN) diff --git a/pairbot/config.py b/pairbot/config.py new file mode 100644 index 0000000..5001e94 --- /dev/null +++ b/pairbot/config.py @@ -0,0 +1,67 @@ +"""config.py + +This module contains pairbot's configuration settings, stored as module-level variables +with uppercase names. +""" + +import os +from pathlib import Path +from datetime import time +import logging.config + +import dotenv +import discord + +BASE_DIR = Path(__file__).resolve().parent.parent + +class ConfigError(Exception): + """Exception raised for errors in application configuration.""" + pass + +ENV = os.environ.get("ENV", "development") +if ENV not in ["development", "testing", "production"]: + raise ConfigError(f"Invalid environment: {ENV}") + +DOTENV_PATH = os.path.join(BASE_DIR, f".env.{ENV}") +if not os.path.isfile(DOTENV_PATH): + raise ConfigError(f"Could not find dotenv file {DOTENV_PATH}.") +dotenv.load_dotenv(DOTENV_PATH) + +# Load required environment variables +try: + DISCORD_BOT_TOKEN = os.environ["DISCORD_BOT_TOKEN"] + DATABASE_URL = os.environ["DATABASE_URL"] + PAIRING_TIME = time.fromisoformat(os.environ["PAIRING_TIME"]) +except KeyError as e: + raise ConfigError(f"Missing environment variable: {e.args[0]}.") + +# Configure logger +LOGGING = { + "version": 1, + "formatters": { + "default": { + "format": "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + "datefmt": "%H:%M:%S", + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + } + }, + "root": { + "handlers": ["console"], + "level": "DEBUG", + } +} + +logging.config.dictConfig(LOGGING) + +logger = logging.getLogger(__name__) + +# Log config information +logger.info(f"Running in {ENV} mode.") + +# Disable PyNaCl warning +discord.VoiceClient.warn_nacl = False diff --git a/pairbot/db.py b/pairbot/db.py deleted file mode 100644 index a87df35..0000000 --- a/pairbot/db.py +++ /dev/null @@ -1,133 +0,0 @@ -import sqlite3 -from contextlib import closing -from enum import Enum, auto -from typing import List, Optional - - -class Timeblock(Enum): - WEEK = auto() - Monday = auto() - Tuesday = auto() - Wednesday = auto() - Thursday = auto() - Friday = auto() - Saturday = auto() - Sunday = auto() - - def __str__(self): - return self.name - - @staticmethod - def generate_schedule(timeblocks: List["Timeblock"]) -> str: - return f"{[str(block) for block in sorted(timeblocks, key=lambda block: block.value)]}" - - -class ScheduleDB: - def __init__(self, path: str) -> None: - self.db = path - self.con = sqlite3.connect(self.db) - self._setup() - - def _setup(self) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "CREATE TABLE IF NOT EXISTS users (guildid INTEGER, userid INTEGER, timeblock INTEGER, " - "UNIQUE (guildid, userid, timeblock))" - ) - self.con.commit() - - def insert(self, guild_id: int, user_id: int, timeblock: Timeblock) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "INSERT INTO users VALUES (?, ?, ?)", - (guild_id, user_id, timeblock.value), - ) - self.con.commit() - - def delete(self, guild_id: int, user_id: int, timeblock: Timeblock) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE from users WHERE guildid=? and userid=? and timeblock=?", - (guild_id, user_id, timeblock.value), - ) - self.con.commit() - - def unsubscribe(self, guild_id: int, user_id: int) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE FROM users WHERE guildid=? and userid=?", (guild_id, user_id) - ) - self.con.commit() - - def query_timeblock(self, guild_id: int, timeblock: Timeblock) -> List[int]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT userid FROM users WHERE guildid=? and timeblock=?", - (guild_id, timeblock.value), - ) - userids = list(map(lambda x: x[0], res.fetchall())) - return userids - - def query_userid(self, guild_id: int, user_id: int) -> List[Timeblock]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT timeblock FROM users WHERE guildid=? and userid=?", - (guild_id, user_id), - ) - timeblocks = list(map(lambda x: Timeblock(x[0]), res.fetchall())) - return timeblocks - - -class PairingsDB: - def __init__(self, path: str) -> None: - self.db = path - self.con = sqlite3.connect(self.db) - self._setup() - - @staticmethod - def _serialize_userids(userids: List[int]) -> str: - return ",".join(map(str, sorted(userids))) - - @staticmethod - def _deserialize_userids(userids_ser: str) -> List[int]: - return list(map(int, userids_ser.split(","))) - - def _setup(self) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "CREATE TABLE IF NOT EXISTS pairings (guildid INTEGER, userids TEXT, channelid INTEGER, " - "threadid INTEGER, UNIQUE (guildid, userids, channelid, threadid))" - ) - self.con.commit() - - def insert( - self, guild_id: int, userids: List[int], channelid: int, threadid: int - ) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "INSERT INTO pairings VALUES (?, ?, ?, ?)", - (guild_id, PairingsDB._serialize_userids(userids), channelid, threadid), - ) - self.con.commit() - - def delete( - self, guild_id: int, userids: List[int], channelid: int, threadid: int - ) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE from pairings WHERE guildid=? and userids=? and channelid=? and threadid=?", - (guild_id, PairingsDB._serialize_userids(userids), channelid, threadid), - ) - self.con.commit() - - def query_userids( - self, guild_id: int, userids: List[int], channelid: int - ) -> Optional[int]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT threadid FROM pairings WHERE guildid=? and userids=? and channelid=?", - (guild_id, PairingsDB._serialize_userids(userids), channelid), - ) - results = res.fetchall() - if len(results) == 1: - return results[0][0] diff --git a/pairbot/main.py b/pairbot/main.py new file mode 100644 index 0000000..0175605 --- /dev/null +++ b/pairbot/main.py @@ -0,0 +1,810 @@ +import pdb + +from typing import ( + Any, + Optional, +) + +import random +from datetime import ( + date, + datetime, + timedelta, +) +import dateparser + +import discord +import discord.ext.tasks + +import sqlalchemy +import sqlalchemy.orm + +from . import config + +from .app import DiscordApp +from .models import ( + PairingChannel, + Schedule, + Skip, + Weekday, + Pairing, +) + +app = DiscordApp(intents=discord.Intents.all()) + +db_engine = sqlalchemy.create_engine(config.DATABASE_URL) +Session = sqlalchemy.orm.sessionmaker(db_engine) + + +# Utility functions + + +def get_active_pairing_channel( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Optional[PairingChannel]: + if interaction.channel is None: return None + return ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel.id) + .one_or_none() + ) + + +def get_user_schedule( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Optional[Schedule]: + if interaction.channel is None: return None + return ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() + ) + + +def make_user_schedule( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Schedule: + assert interaction.channel is not None + schedule = Schedule( + channel_id = interaction.channel_id, + user_id = interaction.user.id, + ) + session.add(schedule) + app.logger.info("Created new schedule.") + return schedule + + +def parse_human_readable_date(maybe_a_date: str) -> Optional[date]: + # Preprocess with some ad-hoc rules + maybe_a_date = ( + maybe_a_date + .lower() + .replace("next", "") + ) + parsed_date = dateparser.parse( + maybe_a_date, + settings={ + "PREFER_DATES_FROM": "future", + "RELATIVE_BASE": datetime.now(), + }, + languages=["en"], + ) + if parsed_date is not None: + return parsed_date.date() + + +def get_next_scheduled_date(schedule: Schedule) -> Optional[date]: + current_weekday = datetime.now().weekday() + for i in range(0, 7): + next_weekday = Weekday((current_weekday + i) % 7) + if schedule.is_available_on(next_weekday): + return datetime.now().date() + timedelta(days=i) + return None + + +def format_date(d: date): + return d.strftime('%A %B %d, %Y') + + +def get_skip( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, + d: date, +) -> Optional[Skip]: + return ( + session.query(Skip) + .filter(Skip.channel_id == interaction.channel_id) + .filter(Skip.user_id == interaction.user.id) + .filter(sqlalchemy.func.DATE(Skip.date) == d) + .one_or_none() + ) + + +def make_skip( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, + d: date, +) -> Skip: + skip = Skip( + channel_id=interaction.channel_id, + user_id=interaction.user.id, + date=d, + ) + session.add(skip) + app.logger.info("Created new skip.") + return skip + + +async def fail_on_inactive_channel( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> bool: + channel = get_active_pairing_channel(session, interaction) + if channel is None: + app.logger.info("Pairbot not active in channel.") + await interaction.response.send_message( + "Pairbot is not active in this channel.", + ephemeral=True + ) + return True + else: + return False + + +async def fail_on_existing_subscription( + interaction: discord.Interaction, + schedule: Schedule, + weekday: Optional[Weekday], +) -> bool: + assert isinstance(interaction.channel, discord.TextChannel) + if weekday is not None and schedule.is_available_on(weekday): + app.logger.info("Already subscribed on %s", weekday) + await interaction.response.send_message( + "You are already subscribed to pair programming on %s in #%s." % ( + weekday, + interaction.channel.name), + ephemeral=True + ) + return True + elif schedule.num_days_available() == 7: + app.logger.info("Already subscribed daily") + await interaction.response.send_message( + "You are already subscribed to daily pair programming in #%s." % interaction.channel.name, + ephemeral=True + ) + return True + return False + + +async def fail_on_nonexistent_subscription( + interaction: discord.Interaction, + schedule: Optional[Schedule], + weekday: Optional[Weekday], +) -> bool: + assert isinstance(interaction.channel, discord.TextChannel) + if schedule is not None and weekday is not None and not schedule.is_available_on(weekday): + app.logger.info("Not subscribed on %s", weekday) + await interaction.response.send_message( + "You are not subscribed to pair programming on %s in #%s." % ( + weekday, + interaction.channel.name), + ephemeral=True + ) + return True + elif schedule is None or schedule.num_days_available() == 0: + await interaction.response.send_message( + ("You are not subscribed to pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + return True + return False + + +def get_pairing( + session: sqlalchemy.orm.Session, + user_1: discord.User | discord.Member, + user_2: discord.User | discord.Member, + user_3: Optional[discord.User | discord.Member] +) -> Optional[Pairing]: + if user_3 is None: + return ( + session.query(Pairing) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_1.id, + Pairing.user_2_id == user_1.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_2.id, + Pairing.user_2_id == user_2.id, + )) + .one_or_none() + ) + else: + return ( + session.query(Pairing) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_1.id, + Pairing.user_2_id == user_1.id, + Pairing.user_3_id == user_1.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_2.id, + Pairing.user_2_id == user_2.id, + Pairing.user_3_id == user_2.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_3.id, + Pairing.user_2_id == user_3.id, + Pairing.user_3_id == user_3.id, + )) + .one_or_none() + ) + + +# Slash commands + + +@app.command( + name="addpairbot", + description="Add Pairbot to the current channel.", +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _add_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() + ) + + if channel is not None: + app.logger.info("Pairbot is already added.") + await interaction.response.send_message( + "Pairbot is already added to \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + return + + channel = PairingChannel( + guild_id = interaction.guild.id, + channel_id = interaction.channel.id, + leetcode_integration = False, # TODO + ) + session.add(channel) + + app.logger.info("Added Pairbot.") + + await interaction.response.send_message( + "Added Pairbot to \"#%s\"." % interaction.channel.name, + ) + + +@app.command( + name="removepairbot", + description="Remove Pairbot from the current channel." +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _remove_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() + ) + + if channel is None: + app.logger.info("Pairbot is not added.") + await interaction.response.send_message( + "Pairbot is not added to \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + return + + session.delete(channel) + + app.logger.info("Removed pairbot.") + + await interaction.response.send_message( + "Removed Pairbot from \"#%s\"." % interaction.channel.name, + ) + + +@app.command( + name="subscribe", + description="Subscribe to pair programming (every day if no weekday specified)." +) +@discord.app_commands.guild_only() +async def _subscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if schedule is None: + schedule = make_user_schedule(session, interaction) + else: + app.logger.info("Found existing schedule") + + if await fail_on_existing_subscription(interaction, schedule, weekday): + return + + if weekday is not None: + schedule.set_availability_on(weekday, True) + app.logger.info("Subscribed to pair programming on %s", weekday) + await interaction.response.send_message( + "Successfully subscribed to pair programming on %s in #%s." % (weekday, interaction.channel.name), + ephemeral=True + ) + else: + schedule.set_availability_every_day(True) + app.logger.info("Subscribed to daily pair programming") + await interaction.response.send_message( + "Successfully subscribed to daily pair programming in #%s." % interaction.channel.name, + ephemeral=True + ) + + +@app.command( + name="unsubscribe", + description="Unsubscribe from pair programming (every day if no weekday specified)." +) +@discord.app_commands.guild_only() +async def _unsubscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if schedule is None or schedule.num_days_available() == 0: + await interaction.response.send_message( + ("You are already not subscribed to pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + return + + if await fail_on_nonexistent_subscription(interaction, schedule, weekday): + return + + if weekday is not None: + schedule.set_availability_on(weekday, False) + await interaction.response.send_message( + ("Successfully unsubscribed from pair programming on %s in #%s." % (weekday, interaction.channel.name)), + ephemeral=True + ) + else: + schedule.set_availability_every_day(False) + await interaction.response.send_message( + ("Successfully unsubscribed from all pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + + +@app.command( + name="skip", + description="Mark yourself as unavailable for pair programming on some date in the future)." +) +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _skip( + interaction: discord.Interaction, + human_date: Optional[str], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if await fail_on_nonexistent_subscription(interaction, schedule, None): + return + + assert schedule is not None + + if human_date is None: + skipped_date = get_next_scheduled_date(schedule) + assert skipped_date is not None + else: + skipped_date = parse_human_readable_date(human_date) + if skipped_date is None: + await interaction.response.send_message( + "Could not parse date \"%s\"." % date, + ephemeral=True + ) + return + + if skipped_date < datetime.now().date(): + await interaction.response.send_message( + "Cannot skip a date in the past: %s." % format_date(skipped_date), + ephemeral=True + ) + return + + skipped_weekday = Weekday(skipped_date.weekday()) + + if await fail_on_nonexistent_subscription(interaction, schedule, skipped_weekday): + return + + skip = get_skip(session, interaction, skipped_date) + + if skip is None: + skip = make_skip(session, interaction, skipped_date) + else: + await interaction.response.send_message( + "You already skipped pairing on %s." % format_date(skipped_date), + ephemeral=True + ) + return + + app.logger.info("Skipped pair programming on %s" % format_date(skipped_date)) + await interaction.response.send_message( + "Successfully skipped pair programming on %s" % format_date(skipped_date), + ephemeral=True + ) + + +@app.command( + name="unskip", + description="Unskip a skipped pairing session." +) +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _unskip( + interaction: discord.Interaction, + human_date: Optional[str], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if await fail_on_nonexistent_subscription(interaction, schedule, None): + return + assert schedule is not None + + if human_date is None: + unskipped_date = get_next_scheduled_date(schedule) + assert unskipped_date is not None + else: + unskipped_date = parse_human_readable_date(human_date) + if unskipped_date is None: + await interaction.response.send_message( + "Could not parse date \"%s\"." % date, + ephemeral=True + ) + return + + if unskipped_date < datetime.now().date(): + await interaction.response.send_message( + "Cannot unskip a date in the past: %s." % format_date(unskipped_date), + ephemeral=True + ) + return + + unskipped_weekday = Weekday(unskipped_date.weekday()) + + if await fail_on_nonexistent_subscription(interaction, schedule, unskipped_weekday): + return + + skip = get_skip(session, interaction, unskipped_date) + + if skip is None: + await interaction.response.send_message( + "You did not skip pair programming on %s." % format_date(unskipped_date), + ephemeral=True + ) + return + else: + session.delete(skip) + + app.logger.info("Unskipped pair programming on %s" % format_date(unskipped_date)) + await interaction.response.send_message( + "Successfully unskipped pair programming on %s" % format_date(unskipped_date), + ephemeral=True + ) + + +@app.command( + name="viewschedule", + description="View your pair programming schedule." +) +@discord.app_commands.guild_only() +async def _view_schedule( + interaction: discord.Interaction, +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + all_schedules = ( + session.query(Schedule) + .filter(Schedule.user_id == interaction.user.id) + .all() + ) + + all_skips = ( + session.query(Skip) + .filter(Skip.user_id == interaction.user.id) + .filter(sqlalchemy.func.DATE(Skip.date) >= datetime.now()) + .all() + ) + + grouped_schedules: dict[str, list[str]] = dict() + grouped_skips: dict[str, list[str]] = dict() + + for schedule in all_schedules: + channel = interaction.guild.get_channel(schedule.channel_id) + if channel is None: continue + channel_name = channel.name + grouped_schedules[channel_name] = list(map(str, schedule.days_available())) + + for skip in all_skips: + channel = interaction.guild.get_channel(skip.channel_id) + if channel is None: continue + channel_name = channel.name + if channel_name not in grouped_skips: + grouped_skips[channel_name] = [] + grouped_skips[channel_name].append(skip.date.strftime("%a %b %d")) + + if len(all_schedules) == 0 or sum([len(days) for days in grouped_schedules.values()]) == 0: + await interaction.response.send_message( + "You are not subscribed to pair programming in any channel.", + ephemeral=True + ) + return + + if len(grouped_schedules) == 1: + channel_name, weekdays = grouped_schedules.popitem() + if len(weekdays) == 7: + await interaction.response.send_message( + "You are subscribed to pair programming in #%s every day%s" % ( + channel_name, + " (skipping " + ", ".join(grouped_skips[channel_name]) + ")" if channel_name in grouped_skips else "" + ), + ephemeral=True + ) + else: + await interaction.response.send_message( + "You are subscribed to pair programming in #%s on %s%s" % ( + channel_name, + ", ".join(weekdays), + " (skipping " + ", ".join(grouped_skips[channel_name]) + ")" if channel_name in grouped_skips else "" + ), + ephemeral=True + ) + return + + response = "You are subscribed to pair programming in the following channels:\n" + for channel_name, weekdays in grouped_schedules.items(): + response += "* #%s: %s%s\n" % ( + channel_name, + ", ".join(weekdays), + "skipping " + ", ".join(grouped_skips[channel_name]) if channel_name in grouped_skips else "" + ) + await interaction.response.send_message(response, ephemeral=True) + + + +@app.command( + name="pairwith", + description="Start a pairing session with another channel member." +) +@discord.app_commands.guild_only() +async def _pair_with( + interaction: discord.Interaction, + user: discord.Member, +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + if user.id == interaction.user.id: + await interaction.response.send_message( + f"You cannot pair with yourself.", + ephemeral=True + ) + return + + with Session.begin() as session: + if await fail_on_inactive_channel(session, interaction): + return + + pairing = get_pairing(session, interaction.user, user, None) + + usernames = sorted([user.name for user in (interaction.user, user)]) + + if pairing is None: + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=interaction.channel_id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + else: + thread = interaction.guild.get_thread(pairing.thread_id) + if thread is None: + session.delete(pairing) + + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=interaction.channel_id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + + await thread.send( + f"<@{interaction.user.id}> has started a pairing session with you, <@{user.id}>. Happy pairing! :computer:" + ) + await interaction.response.send_message( + f"Successfully created pairing thread with <@{user.id}>", + ephemeral=True + ) + + +@app.command( + name="makegroups", + description="Assign random groups to people subscribed to pairing today." +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _make_groups( + interaction: discord.Interaction[Any] +): + await make_groups() + await interaction.response.send_message( + f"Great success!! :-)", + ephemeral=True + ) + + +@discord.ext.tasks.loop(time=config.PAIRING_TIME) +async def make_groups(): + app.logger.info("Creating random groups.") + + with Session.begin() as session: + for guild in app.client.guilds: + channels = ( + session.query(PairingChannel) + .filter(PairingChannel.guild_id == guild.id) + .all() + ) + for pairing_channel in channels: + channel = app.client.get_channel(pairing_channel.channel_id) + if channel is None: + session.delete(pairing_channel) + app.logger.info("Pairing channel is gone :(") + continue + assert isinstance(channel, discord.TextChannel) + schedules = ( + session.query(Schedule) + .filter(Schedule.channel_id == channel.id) + .all() + ) + + todays_users = [ + guild.get_member(schedule.user_id) + for schedule in schedules + if schedule.is_available_on(Weekday(datetime.now().weekday())) + ] + + if len(todays_users) < 2: + app.logger.info("Not enough users to make groups.") + + random.shuffle(todays_users) + + groups = [] + if len(todays_users) % 2 == 1: + pair = ( + todays_users.pop(), + todays_users.pop(), + todays_users.pop(), + ) + groups.append(pair) + + while len(todays_users) > 0: + pair = ( + todays_users.pop(), + todays_users.pop(), + None, + ) + groups.append(pair) + + for group in groups: + usernames = sorted([user.name for user in group if user is not None]) + user1, user2, user3 = group + + pairing = get_pairing(session, user1, user2, user3) + + if pairing is None: + # Make new discord thread + thread = await channel.create_thread( + name=" & ".join(usernames) + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=channel.id, + user_1_id=user1.id, + user_2_id=user2.id, + user_3_id=user3.id if user3 is not None else None, + ) + session.add(pairing) + else: + thread = guild.get_thread(pairing.thread_id) + if thread is None: + # Thread deleted somehow. Start over + session.delete(pairing) + + # Make new discord thread + thread = await channel.create_thread( + name=" & ".join(usernames) + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=channel.id, + user_1_id=user1.id, + user_2_id=user2.id, + user_3_id=user3.id if user3 is not None else None, + ) + session.add(pairing) + + msg = ", ".join(("<@%d>" % user.id) for user in group if user is not None) + msg += ": " + msg += "You have been matched together. Happy pairing! :computer:" + + await thread.send(msg) + + app.logger.info( + "Made pairing group with %s" % " & ".join(usernames) + ) + + +def run(): + app.run(config.DISCORD_BOT_TOKEN) diff --git a/pairbot/models.py b/pairbot/models.py new file mode 100644 index 0000000..e7794c0 --- /dev/null +++ b/pairbot/models.py @@ -0,0 +1,144 @@ +"""models.py + +This module contains Peewee ORM models and their respective logic. +""" +import calendar +import enum +from datetime import datetime + +from typing import ( + Optional, + Union, +) + +from sqlalchemy import ( + PrimaryKeyConstraint, + Boolean, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, +) + + +class Base(DeclarativeBase): + pass + + +class PairingChannel(Base): + """A Discord channel in which Pairbot has been allowed to operate.""" + __tablename__ = "channel" + + guild_id: Mapped[int] + """The Discord guild ID.""" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + leetcode_integration: Mapped[bool] + """Whether leetcode integration is active in the channel or not.""" + + __table_args__ = ( + PrimaryKeyConstraint("guild_id", "channel_id",), + ) + + +class Weekday(enum.IntEnum): + """An enum representing the day of the week (0-6 starting from Monday).""" + Monday = 0 + Tuesday = 1 + Wednesday = 2 + Thursday = 3 + Friday = 4 + Saturday = 5 + Sunday = 6 + + def __str__(self): + return calendar.day_name[self] + + +class Schedule(Base): + """Represents the availability of a user in a channel on a given day of the week""" + __tablename__ = "schedule" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + user_id: Mapped[int] + """The Discord user ID.""" + + available_0: Mapped[bool] = mapped_column(Boolean, default=False) + available_1: Mapped[bool] = mapped_column(Boolean, default=False) + available_2: Mapped[bool] = mapped_column(Boolean, default=False) + available_3: Mapped[bool] = mapped_column(Boolean, default=False) + available_4: Mapped[bool] = mapped_column(Boolean, default=False) + available_5: Mapped[bool] = mapped_column(Boolean, default=False) + available_6: Mapped[bool] = mapped_column(Boolean, default=False) + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "user_id"), + ) + + # Utility methods + + def is_available_on(self, day_of_week: Weekday) -> bool: + return getattr(self, f"available_{int(day_of_week)}") + + def set_availability_on(self, day_of_week: Weekday, value: bool): + setattr(self, f"available_{int(day_of_week)}", value) + + def set_availability_every_day(self, value: bool): + for day in Weekday: + self.set_availability_on(day, value) + + def days_available(self) -> list[Weekday]: + return [day for day in Weekday if self.is_available_on(day)] + + def days_unavailable(self) -> list[Weekday]: + return [day for day in Weekday if not self.is_available_on(day)] + + def num_days_available(self) -> int: + return sum(1 for day in Weekday if self.is_available_on(day)) + + +class Skip(Base): + """Represents an adjustment to a user's availability on a specific date.""" + __tablename__ = "skip" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + user_id: Mapped[int] + """The Discord user ID.""" + + date: Mapped[datetime] + """The date on which the user's availability is set.""" + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "user_id", "date"), + ) + + +class Pairing(Base): + """Represents a Discord thread created by Pairbot.""" + __tablename__ = "pairing" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + thread_id: Mapped[int] + """The Discord thread ID""" + + user_1_id: Mapped[int] + """The first Discord user ID.""" + + user_2_id: Mapped[int] + """The second Discord user ID.""" + + user_3_id: Mapped[Optional[int]] + """The third Discord user ID.""" + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "thread_id"), + ) diff --git a/pairbot/utils.py b/pairbot/utils.py deleted file mode 100644 index fb67c83..0000000 --- a/pairbot/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -import argparse -import json - -import discord - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--dev", action="store_true", help="use dev environment") - return parser.parse_args() - - -def read_guild_to_channel(path: str): - with open(path, "r") as f: - return json.load(f) - - -def get_user_name(user: discord.User): - if user.global_name is not None: - return user.global_name - elif user.nick is not None: - return user.nick - else: - return user.name diff --git a/pyproject.toml b/pyproject.toml index ef28a21..cd7e039 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,11 @@ name = "pairbot" version = "0.0.1" readme = "README.md" dependencies = [ + "SQLAlchemy", + "alembic", "discord", - "python-dotenv" + "python-dotenv", + "dateparser" ] [project.urls] @@ -21,6 +24,7 @@ pairbot = "pairbot:run" # Testing [tool.hatch.envs.default] +python = "3.11" dependencies = [ "coverage[toml]>=6.5", "pytest",