From b5db7618101bbc3d3c735ea1b6ea7bb0bc54c3b5 Mon Sep 17 00:00:00 2001 From: Yaxel Date: Sat, 25 Nov 2023 14:06:20 -0800 Subject: [PATCH 1/3] Rewrite :v --- README.md | 7 + alembic.ini | 116 +++ alembic/README | 1 + alembic/env.py | 82 ++ alembic/script.py.mako | 26 + .../versions/c2ab7b1ce2f4_default_values.py | 30 + .../d588346e0112_nuked_migrations_v.py | 64 ++ pairbot/client.old.py | 384 +++++++ pairbot/client.py | 962 ++++++++++++------ pairbot/config.py | 67 ++ pairbot/models.py | 129 +++ pairbot/utils.py | 1 + pyproject.toml | 6 +- 13 files changed, 1541 insertions(+), 334 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/c2ab7b1ce2f4_default_values.py create mode 100644 alembic/versions/d588346e0112_nuked_migrations_v.py create mode 100644 pairbot/client.old.py create mode 100644 pairbot/config.py create mode 100644 pairbot/models.py diff --git a/README.md b/README.md index 7ad106b..b33b7f9 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,13 @@ hatch env create hatch run dev ``` +Database migrations are handled with [Alembic](https://alembic.sqlalchemy.org). + +``` sh +alembic revision --autogenerate -m "Brief description of changes..." +alembic upgrade head +``` + Run [black](https://pypi.org/project/black/) and [isort](https://pycqa.github.io/isort/index.html) on your code: ``` sh diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..1e8c258 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..5cdf272 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,82 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from pairbot import config as pairbot_config +from pairbot import models + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +config.set_main_option("sqlalchemy.url", pairbot_config.DATABASE_URL) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = models.Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/c2ab7b1ce2f4_default_values.py b/alembic/versions/c2ab7b1ce2f4_default_values.py new file mode 100644 index 0000000..1b47059 --- /dev/null +++ b/alembic/versions/c2ab7b1ce2f4_default_values.py @@ -0,0 +1,30 @@ +"""default values + +Revision ID: c2ab7b1ce2f4 +Revises: d588346e0112 +Create Date: 2023-11-25 00:16:18.559944 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c2ab7b1ce2f4' +down_revision: Union[str, None] = 'd588346e0112' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/alembic/versions/d588346e0112_nuked_migrations_v.py b/alembic/versions/d588346e0112_nuked_migrations_v.py new file mode 100644 index 0000000..627ff9c --- /dev/null +++ b/alembic/versions/d588346e0112_nuked_migrations_v.py @@ -0,0 +1,64 @@ +"""Nuked migrations :v + +Revision ID: d588346e0112 +Revises: +Create Date: 2023-11-25 00:12:31.330617 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd588346e0112' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('channel', + sa.Column('guild_id', sa.Integer(), nullable=False), + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=False), + sa.Column('leetcode_integration', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('guild_id', 'channel_id') + ) + op.create_table('pairing', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('thread_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('channel_id', 'thread_id', 'user_id') + ) + op.create_table('schedule', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('available_Monday', sa.Boolean(), nullable=False), + sa.Column('available_Tuesday', sa.Boolean(), nullable=False), + sa.Column('available_Wednesday', sa.Boolean(), nullable=False), + sa.Column('available_Thursday', sa.Boolean(), nullable=False), + sa.Column('available_Friday', sa.Boolean(), nullable=False), + sa.Column('available_Saturday', sa.Boolean(), nullable=False), + sa.Column('available_Sunday', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('channel_id', 'user_id') + ) + op.create_table('schedule_adjustment', + sa.Column('channel_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('date', sa.DateTime(), nullable=False), + sa.Column('available', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('channel_id', 'user_id', 'date') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('schedule_adjustment') + op.drop_table('schedule') + op.drop_table('pairing') + op.drop_table('channel') + # ### end Alembic commands ### diff --git a/pairbot/client.old.py b/pairbot/client.old.py new file mode 100644 index 0000000..7b77418 --- /dev/null +++ b/pairbot/client.old.py @@ -0,0 +1,384 @@ +import json +import logging +import os +import random +import sqlite3 +from datetime import datetime +from pathlib import Path +from typing import List + +import discord +from discord import app_commands +from discord.ext import tasks +from dotenv import load_dotenv + +from .db import PairingsDB, ScheduleDB, Timeblock +from .utils import get_user_name, parse_args, read_guild_to_channel + +load_dotenv() +args = parse_args() +if args.dev: + print("Running in dev mode.") + BOT_TOKEN = os.getenv("BOT_TOKEN_DEV") + DATA_DIR = "data" + GUILDS_PATH = f"{DATA_DIR}/guilds-dev.json" + SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule-dev.db" + PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings-dev.db" + LOG_FILE = "pairbot-dev.log" +else: + print("Running in prod mode.") + BOT_TOKEN = os.getenv("BOT_TOKEN") + DATA_DIR = "data" + GUILDS_PATH = f"{DATA_DIR}/guilds.json" + SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule.db" + PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings.db" + LOG_FILE = "pairbot.log" +SORRY = "Unexpected error." +Path(DATA_DIR).mkdir(parents=True, exist_ok=True) + +logging.basicConfig( + filename=LOG_FILE, + filemode="a", + format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + level=logging.DEBUG, +) +logger = logging.getLogger("pairbot") + +intents = discord.Intents.default() +intents.members = True +intents.message_content = True +client = discord.Client(intents=intents) +tree = app_commands.CommandTree(client) +db = ScheduleDB(SCHEDULE_DB_PATH) +pairings_db = PairingsDB(PAIRINGS_DB_PATH) + + +# Discord API currently doesn't support variadic arguments +# https://github.com/discord/discord-api-docs/discussions/3286 +@tree.command( + name="subscribe", + description="Add timeblocks to find a partner for pair programming. \ +Matches go out at 8am UTC that day.", +) +@app_commands.describe( + timeblock="Choose WEEK to get a partner for the whole week (pairs announced Monday UTC)." +) +@app_commands.choices( + timeblock=[ + app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), + app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), + app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), + app_commands.Choice( + name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value + ), + app_commands.Choice( + name=Timeblock.Thursday.name, value=Timeblock.Thursday.value + ), + app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), + app_commands.Choice( + name=Timeblock.Saturday.name, value=Timeblock.Saturday.value + ), + app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), + ] +) +async def _subscribe(interaction: discord.Interaction, timeblock: Timeblock): + try: + db.insert(interaction.guild_id, interaction.user.id, timeblock) + timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} subscribed T:{timeblock.name}." + ) + msg = ( + f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`. " + f"You can call `/subscribe` again to sign up for more days." + ) + await interaction.response.send_message(msg, ephemeral=True) + except sqlite3.IntegrityError as e: + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} failed subscribe T:{timeblock.name}." + ) + logger.warning(e, exc_info=True) + msg = ( + f"You are already subscribed to {timeblock}. " + f"Call `/unsubscribe` to remove a subscription or `/schedule` to view your schedule." + ) + await interaction.response.send_message(msg, ephemeral=True) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +@tree.command(name="unsubscribe", description="Remove timeblocks for pair programming.") +@app_commands.describe(timeblock="Call `/unsubscribe-all` to remove all timeblocks.") +@app_commands.choices( + timeblock=[ + app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), + app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), + app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), + app_commands.Choice( + name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value + ), + app_commands.Choice( + name=Timeblock.Thursday.name, value=Timeblock.Thursday.value + ), + app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), + app_commands.Choice( + name=Timeblock.Saturday.name, value=Timeblock.Saturday.value + ), + app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), + ] +) +async def _unsubscribe(interaction: discord.Interaction, timeblock: Timeblock): + try: + db.delete(interaction.guild_id, interaction.user.id, timeblock) + timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} unsubscribed T:{timeblock.name}." + ) + msg = f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`." + await interaction.response.send_message(msg, ephemeral=True) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +@tree.command( + name="unsubscribe-all", description="Remove all timeblocks for pair programming." +) +async def _unsubscribe_all(interaction: discord.Interaction): + try: + db.unsubscribe(interaction.guild_id, interaction.user.id) + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} called unsubscribe-all." + ) + msg = "Your pairing subscriptions have been removed. To rejoin, call `/subscribe` again." + await interaction.response.send_message(msg, ephemeral=True) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +@tree.command(name="schedule", description="View your pairing schedule.") +async def _schedule(interaction: discord.Interaction): + try: + timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) + schedule = Timeblock.generate_schedule(timeblocks) + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} queried schedule {schedule}." + ) + msg = ( + f"Your current schedule is `{schedule}`. " + "You can call `/subscribe` or `/unsubscribe` to modify it." + ) + await interaction.response.send_message(msg, ephemeral=True) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +@tree.command( + name="set-channel", description="Set a channel for bot messages (admin only)." +) +@app_commands.checks.has_permissions(administrator=True) +async def _set_channel(interaction: discord.Interaction, channel: discord.TextChannel): + try: + guild_to_channel = read_guild_to_channel(GUILDS_PATH) + guild_to_channel[str(interaction.guild_id)] = channel.id + with open(GUILDS_PATH, "w") as f: + json.dump(guild_to_channel, f) + logger.info( + f"G:{interaction.guild_id} U:{interaction.user.id} set-channel C:{channel.id}." + ) + msg = f"Successfully set bot channel to `{channel.name}`." + await interaction.response.send_message(msg, ephemeral=True) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +@tree.command( + name="pairwith", + description="Start an immediate pairing session with another member.", +) +async def _pairwith(interaction: discord.Interaction, user: discord.Member): + try: + guild_to_channel = read_guild_to_channel(GUILDS_PATH) + channel_id = guild_to_channel[str(interaction.guild_id)] + channel = client.get_channel(channel_id) + users = [interaction.user, user] + notify_msg = ( + f"<@{interaction.user.id}> has started an on-demand pair with you, <@{user.id}>. " + "Happy pairing! :computer:" + ) + await create_group_thread(interaction.guild_id, users, channel, notify_msg) + logger.info( + f"G:{interaction.guild_id} C:{channel.id} on-demand paired U:{interaction.user.id} with {user.id}." + ) + await interaction.response.send_message( + f"Thread with {get_user_name(user)} created in channel `{channel.name}`.", + ephemeral=True, + ) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message(SORRY, ephemeral=True) + + +async def dm_user(user: discord.User, msg: str): + try: + channel = await user.create_dm() + await channel.send(msg) + except Exception as e: + logger.error(e, exc_info=True) + + +async def create_group_thread( + guild_id: int, + users: List[discord.User], + channel: discord.TextChannel, + notify_msg: str, +): + # @ notifying users in a private thread invites them + # so `notify_msg` must notify for this to work + userids = [user.id for user in users] + thread_id = pairings_db.query_userids(guild_id, userids, channel.id) + thread = None + if thread_id is not None: + logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") + try: + guild = client.get_guild(guild_id) + thread = await guild.fetch_channel(thread_id) + except discord.errors.NotFound: + logger.debug(f"Couldn't fetch thread {thread_id}, maybe deleted?") + pairings_db.delete(guild_id, userids, channel.id, thread_id) + if thread is None: + title = ", ".join(get_user_name(user) for user in users) + thread = await channel.create_thread( + name=f"{title}", auto_archive_duration=10080 + ) + logger.debug(f"Created new thread {thread.id} for G:{guild_id} U:{userids}") + pairings_db.insert(guild_id, userids, channel.id, thread.id) + else: + logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") + guild = client.get_guild(guild_id) + thread = await guild.fetch_channel(thread_id) + await thread.send(notify_msg) + + +async def on_tree_error( + interaction: discord.Interaction, error: app_commands.AppCommandError +): + if isinstance(error, app_commands.CommandOnCooldown): + return await interaction.response.send_message( + f"Command is currently on cooldown, try again in {error.retry_after:.2f} seconds.", + ephemeral=True, + ) + elif isinstance(error, app_commands.MissingPermissions): + return await interaction.response.send_message( + "You don't have the permissions to do that.", ephemeral=True + ) + else: + raise error + + +@tasks.loop(hours=1) +async def pairing_cron(): + def should_run(): + now = datetime.utcnow() + hour = now.time().hour + logger.debug(f"Checking pairing job at UTC:{now}.") + return hour == 8 + + if should_run(): + await run_pairing() + + +async def run_pairing(): + now = datetime.utcnow() + print(now) + logger.debug(f"Running pairing job at UTC:{now}.") + weekday = now.weekday() + weekday_map = { + 0: Timeblock.Monday, + 1: Timeblock.Tuesday, + 2: Timeblock.Wednesday, + 3: Timeblock.Thursday, + 4: Timeblock.Friday, + 5: Timeblock.Saturday, + 6: Timeblock.Sunday, + } + timeblock = weekday_map[weekday] + for guild in client.guilds: + await pair(guild.id, timeblock) + # weekly Monday match + if weekday == 0: + await pair(guild.id, Timeblock.WEEK) + + +async def pair(guild_id: int, timeblock: Timeblock): + try: + userids = db.query_timeblock(guild_id, timeblock) + users = [client.get_user(userid) for userid in userids] + # Users may leave the server without unsubscribing + # TODO: listen to that event and drop them from the table + users = list(filter(None, users)) + logger.info( + f"Pairing for G:{guild_id} T:{timeblock.name} with {len(users)}/{len(userids)} users." + ) + guild_to_channel = read_guild_to_channel(GUILDS_PATH) + channel = client.get_channel(guild_to_channel[str(guild_id)]) + if len(users) < 2: + for user in users: + logger.info( + f"G:{guild_id} T:{timeblock.name} pair failed, dming U:{user.id}." + ) + msg = ( + f"Thanks for signing up for pairing this {timeblock}. " + "Unfortunately, there was nobody else available this time." + ) + await dm_user(user, msg) + await channel.send( + f"Not enough signups this {timeblock}. Try `/subscribe` to sign up!" + ) + return + + random.shuffle(users) + groups = [users[i :: len(users) // 2] for i in range(len(users) // 2)] + for group in groups: + notify_msg = ", ".join(f"<@{user.id}>" for user in group) + notify_msg = f"{notify_msg}: you've been matched together for this {timeblock}. Happy pairing! :computer:" + await create_group_thread(guild_id, group, channel, notify_msg) + logger.info( + f"G:{guild_id} C:{channel.id} paired U:{[user.id for user in group]}." + ) + await channel.send( + f"Pairings for {len(users)} users have been sent out for this {timeblock}. Try `/subscribe` to sign up!" + ) + except Exception as e: + logger.error(e, exc_info=True) + + +def local_setup(): + try: + read_guild_to_channel(GUILDS_PATH) + except Exception: + with open(GUILDS_PATH, "w") as f: + json.dump({}, f) + + +@client.event +async def on_ready(): + local_setup() + await client.wait_until_ready() + tree.on_error = on_tree_error + for guild in client.guilds: + tree.copy_global_to(guild=guild) + await tree.sync(guild=guild) + print("Code sync complete!") + pairing_cron.start() + print("Starting cron loop...") + logger.info("Bot started.") + + +def run(): + client.run(BOT_TOKEN) diff --git a/pairbot/client.py b/pairbot/client.py index 7b77418..49c1fdc 100644 --- a/pairbot/client.py +++ b/pairbot/client.py @@ -1,384 +1,680 @@ -import json +"""client.py + +This module describes the Discord slash command interface and corresponding logic. +""" + +import functools import logging -import os -import random -import sqlite3 -from datetime import datetime -from pathlib import Path -from typing import List -import discord -from discord import app_commands -from discord.ext import tasks -from dotenv import load_dotenv - -from .db import PairingsDB, ScheduleDB, Timeblock -from .utils import get_user_name, parse_args, read_guild_to_channel - -load_dotenv() -args = parse_args() -if args.dev: - print("Running in dev mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN_DEV") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds-dev.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule-dev.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings-dev.db" - LOG_FILE = "pairbot-dev.log" -else: - print("Running in prod mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings.db" - LOG_FILE = "pairbot.log" -SORRY = "Unexpected error." -Path(DATA_DIR).mkdir(parents=True, exist_ok=True) - -logging.basicConfig( - filename=LOG_FILE, - filemode="a", - format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - level=logging.DEBUG, +from typing import ( + Any, + Optional, + TypeVar, + ParamSpec, + Callable, + Concatenate, + Coroutine, ) -logger = logging.getLogger("pairbot") -intents = discord.Intents.default() -intents.members = True -intents.message_content = True -client = discord.Client(intents=intents) -tree = app_commands.CommandTree(client) -db = ScheduleDB(SCHEDULE_DB_PATH) -pairings_db = PairingsDB(PAIRINGS_DB_PATH) +import sqlalchemy +from sqlalchemy import ( + create_engine, + func +) +import sqlalchemy.orm +from sqlalchemy.orm import ( + Session, + sessionmaker, +) -# Discord API currently doesn't support variadic arguments -# https://github.com/discord/discord-api-docs/discussions/3286 -@tree.command( - name="subscribe", - description="Add timeblocks to find a partner for pair programming. \ -Matches go out at 8am UTC that day.", +import discord +import discord.ext.commands +import discord.ext.tasks + +from datetime import datetime, timedelta +import dateparser + +from . import ( + config, ) -@app_commands.describe( - timeblock="Choose WEEK to get a partner for the whole week (pairs announced Monday UTC)." + +from .models import ( + Weekday, + PairingChannel, + Schedule, + ScheduleAdjustment, + Thread, ) -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] + +logger = logging.getLogger(__name__) + +# type fuckery for the command decorator +T = TypeVar("T") +P = ParamSpec("P") +CommandCallback = Callable[Concatenate[discord.Interaction[Any], P], Coroutine[Any, Any, T]] + +class Pairbot(discord.Client): + """Represents a running instance of Pairbot.""" + + def __init__(self, intents: discord.Intents, **options: Any) -> None: + super().__init__(intents=intents, **options) + self.tree = discord.app_commands.CommandTree(self) + + self.db_engine = create_engine(config.DATABASE_URL) + self.make_orm_session = sessionmaker(self.db_engine) + + def command( + self, + **options: Any + ): + """Wrapper for pairbot slash commands with logging and error-handling.""" + def decorator(callback: CommandCallback): + @self.tree.command(**options) + @discord.app_commands.guild_only() + @functools.wraps(callback) + async def wrapper( + interaction: discord.Interaction[Any], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + # Keep type checker happy + assert interaction.command is not None + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + assert isinstance(interaction.user, discord.Member) + + # Log command execution + pretty_kwargs = ( + " with arguments { " + + ", ".join((f"{key}=\"{str(value)}\"" for key, value in kwargs.items())) + + " }" + if len(kwargs) > 0 else " with no arguments" + ) + logger.info( + f"User \"{interaction.user.name}\" executed command /{interaction.command.name}{pretty_kwargs} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + try: + await callback(interaction, *args, **kwargs) + except Exception as e: + logger.error(e, exc_info=True) + await interaction.response.send_message("Pairbot broke somehow! :v", ephemeral=True) + + return wrapper + return decorator + + async def on_ready(self) -> None: + for guild in self.guilds: + logger.info(f"Copying command tree to guild \"{guild.name}\"") + self.tree.copy_global_to(guild=guild) + await self.tree.sync(guild=guild) + logger.info("Pairbot ready.") + + +# Instantiate client and register slash commands +intents = discord.Intents.all() +client = Pairbot( + intents=intents, ) -async def _subscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.insert(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} subscribed T:{timeblock.name}." - ) - msg = ( - f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`. " - f"You can call `/subscribe` again to sign up for more days." + + +@client.command( + name="addpairbot", + description="Add Pairbot to the current channel." +) +@discord.app_commands.checks.has_permissions(administrator=True) +async def _add_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - await interaction.response.send_message(msg, ephemeral=True) - except sqlite3.IntegrityError as e: + + if channel is not None: + if channel.active: + logger.info( + f"Pairbot is already added to guild \"{interaction.guild.name}\", channel \"#{interaction.channel}\"." + ) + await interaction.response.send_message( + f"Pairbot is already added to \"#{interaction.channel.name}\"." + ) + return + else: + channel.active = True + else: + channel = PairingChannel( + guild_id = interaction.guild.id, + channel_id = interaction.channel.id, + active = True, + leetcode_integration = False, # TODO + ) + session.add(channel) + + session.commit() + logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} failed subscribe T:{timeblock.name}." + f"Added Pairbot to guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." ) - logger.warning(e, exc_info=True) - msg = ( - f"You are already subscribed to {timeblock}. " - f"Call `/unsubscribe` to remove a subscription or `/schedule` to view your schedule." + await interaction.response.send_message( + f"Added Pairbot to \"#{interaction.channel.name}\"." ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="unsubscribe", description="Remove timeblocks for pair programming.") -@app_commands.describe(timeblock="Call `/unsubscribe-all` to remove all timeblocks.") -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] + + +@client.command( + name="removepairbot", + description="Remove Pairbot from the current channel." ) -async def _unsubscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.delete(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} unsubscribed T:{timeblock.name}." +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _remove_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - msg = f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) + + if channel is None or not channel.active: + logger.info( + f"Pairbot is not added to guild \"{interaction.guild.name}\", channel \"#{interaction.channel}\"." + ) + await interaction.response.send_message( + f"Pairbot is not added to \"#{interaction.channel.name}\"." + ) + else: + channel.active = False + session.commit() + logger.info( + f"Removed Pairbot from guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + await interaction.response.send_message( + f"Removed Pairbot from \"#{interaction.channel.name}\"." + ) -@tree.command( - name="unsubscribe-all", description="Remove all timeblocks for pair programming." +@client.command( + name="subscribe", + description="Subscribe to pair programming (every day if no weekday specified)." ) -async def _unsubscribe_all(interaction: discord.Interaction): - try: - db.unsubscribe(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} called unsubscribe-all." - ) - msg = "Your pairing subscriptions have been removed. To rejoin, call `/subscribe` again." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="schedule", description="View your pairing schedule.") -async def _schedule(interaction: discord.Interaction): - try: - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - schedule = Timeblock.generate_schedule(timeblocks) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} queried schedule {schedule}." +@discord.app_commands.guild_only() +async def _subscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - msg = ( - f"Your current schedule is `{schedule}`. " - "You can call `/subscribe` or `/unsubscribe` to modify it." + if channel is None: + await interaction.response.send_message( + f"Pairbot is not active in this channel.", + ephemeral=True + ) + return + + schedule = ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) + if schedule is None: + schedule = Schedule( + channel_id = interaction.channel_id, + user_id = interaction.user.id, + ) + session.add(schedule) + + if weekday is not None: + if schedule[weekday] == True: + await interaction.response.send_message( + f"You are already subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}.", + ephemeral=True + ) + return + else: + schedule[weekday] = True + + session.commit() + + logger.info( + f"Subscribed user \"{interaction.user.name}\" to pair programming on {str(weekday)} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) -@tree.command( - name="set-channel", description="Set a channel for bot messages (admin only)." + msg = f"Successfully subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}." + else: + if len(schedule.days_available) == 7: + await interaction.response.send_message( + f"You are already subscribed to pair programming every day.", + ephemeral=True + ) + return + for day in Weekday: + schedule[day] = True + + session.commit() + + logger.info( + f"Subscribed user \"{interaction.user.name}\" to daily pair programming in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + msg = f"Successfully subscribed to daily pair programming in #{interaction.channel.name}." + await interaction.response.send_message(msg, ephemeral=True) + + +@client.command( + name="unsubscribe", + description="Unsubscribe from pair programming (every day if no weekday specified)." ) -@app_commands.checks.has_permissions(administrator=True) -async def _set_channel(interaction: discord.Interaction, channel: discord.TextChannel): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - guild_to_channel[str(interaction.guild_id)] = channel.id - with open(GUILDS_PATH, "w") as f: - json.dump(guild_to_channel, f) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} set-channel C:{channel.id}." +@discord.app_commands.guild_only() +async def _unsubscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - msg = f"Successfully set bot channel to `{channel.name}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) + if channel is None: + await interaction.response.send_message( + f"Pairbot is not active in this channel.", + ephemeral=True + ) + return + schedule = ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() + ) -@tree.command( - name="pairwith", - description="Start an immediate pairing session with another member.", + if schedule is None: + await interaction.response.send_message( + f"You are already not subscribed to pair programming in #{interaction.channel.name}.", + ephemeral=True + ) + return + + if weekday is not None: + if not schedule[weekday]: + await interaction.response.send_message( + f"You are already not subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}.", + ephemeral=True + ) + return + + schedule[weekday] = False + session.commit() + + logger.info( + f"Unsubscribed user \"{interaction.user.name}\" from pair programming on {str(weekday)} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + msg = f"Successfully unsubscribed from pair programming on {str(weekday)} in #{interaction.channel.name}." + else: + if len(schedule.days_available) == 0: + await interaction.response.send_message( + f"You are already not subscribed to pair programming.", + ephemeral=True + ) + return + for day in Weekday: + schedule[day] = False + session.commit() + + logger.info( + f"Unsubscribed user \"{interaction.user.name}\" from all pair programming in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + msg = f"Successfully unsubscribed from all pair programming in #{interaction.channel.name}." + + await interaction.response.send_message(msg, ephemeral=True) + + +@client.command( + name="skip", + description="Mark yourself as unavailable for pair programming on some date in the future)." ) -async def _pairwith(interaction: discord.Interaction, user: discord.Member): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel_id = guild_to_channel[str(interaction.guild_id)] - channel = client.get_channel(channel_id) - users = [interaction.user, user] - notify_msg = ( - f"<@{interaction.user.id}> has started an on-demand pair with you, <@{user.id}>. " - "Happy pairing! :computer:" - ) - await create_group_thread(interaction.guild_id, users, channel, notify_msg) - logger.info( - f"G:{interaction.guild_id} C:{channel.id} on-demand paired U:{interaction.user.id} with {user.id}." +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _skip( + interaction: discord.Interaction, + human_date: Optional[str], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - await interaction.response.send_message( - f"Thread with {get_user_name(user)} created in channel `{channel.name}`.", - ephemeral=True, + if channel is None: + await interaction.response.send_message( + f"Pairbot is not active in this channel.", + ephemeral=True + ) + return + + schedule = ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() ) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) + if schedule is None or len(schedule.days_available) == 0: + await interaction.response.send_message( + f"You are not subscribed to pair programming in #{interaction.channel.name}.", + ephemeral=True + ) + return -async def dm_user(user: discord.User, msg: str): - try: - channel = await user.create_dm() - await channel.send(msg) - except Exception as e: - logger.error(e, exc_info=True) + if human_date is None: + adjustment_date = None + current_weekday = datetime.now().weekday() + for i in range(0, 7): + weekday = Weekday((current_weekday + i) % 7) + if schedule[weekday]: + adjustment_date = datetime.now().date() + timedelta(days=i) + break + + assert adjustment_date is not None + else: + # Clean things up for the parser + cleaned_date = human_date.lower() + cleaned_date = cleaned_date.replace("next", "") + + # Try to parse the date + adjustment_datetime = dateparser.parse( + cleaned_date, + settings={ + "PREFER_DATES_FROM": "future", + "RELATIVE_BASE": datetime.now(), + }, + languages=["en"], + ) + if adjustment_datetime is None: + await interaction.response.send_message( + f"Could not parse date \"{human_date}\".", + ephemeral=True + ) + return + adjustment_date = adjustment_datetime.date() + if adjustment_datetime < datetime.now(): + await interaction.response.send_message( + f"Cannot skip a date in the past: {adjustment_date.strftime('%A %B %d, %Y')}.", + ephemeral=True + ) + return -async def create_group_thread( - guild_id: int, - users: List[discord.User], - channel: discord.TextChannel, - notify_msg: str, -): - # @ notifying users in a private thread invites them - # so `notify_msg` must notify for this to work - userids = [user.id for user in users] - thread_id = pairings_db.query_userids(guild_id, userids, channel.id) - thread = None - if thread_id is not None: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - try: - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - except discord.errors.NotFound: - logger.debug(f"Couldn't fetch thread {thread_id}, maybe deleted?") - pairings_db.delete(guild_id, userids, channel.id, thread_id) - if thread is None: - title = ", ".join(get_user_name(user) for user in users) - thread = await channel.create_thread( - name=f"{title}", auto_archive_duration=10080 + weekday = Weekday(adjustment_date.weekday()) + if schedule[weekday] == False: + await interaction.response.send_message( + f"You are not subscribed to pair programming in #{interaction.channel.name} on {weekday}.", + ephemeral=True + ) + return + + adjustment = ( + session.query(ScheduleAdjustment) + .filter(ScheduleAdjustment.channel_id == interaction.channel_id) + .filter(ScheduleAdjustment.user_id == interaction.user.id) + .filter(func.DATE(ScheduleAdjustment.date) == adjustment_date) + .one_or_none() ) - logger.debug(f"Created new thread {thread.id} for G:{guild_id} U:{userids}") - pairings_db.insert(guild_id, userids, channel.id, thread.id) - else: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - await thread.send(notify_msg) + if adjustment is None: + adjustment = ScheduleAdjustment( + channel_id=interaction.channel_id, + user_id=interaction.user.id, + date=adjustment_date, + available=False, + ) + session.add(adjustment) + else: + if adjustment.available == False: + await interaction.response.send_message( + f"You already skipped pairing on {adjustment_date.strftime('%A %B %d, %Y')}.", + ephemeral=True + ) + return + adjustment.available = False + session.commit() + + logger.info( + f"Skipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')} for user \"{interaction.user.name}\" in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + msg = f"Successfully skipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')}." + await interaction.response.send_message(msg, ephemeral=True) -async def on_tree_error( - interaction: discord.Interaction, error: app_commands.AppCommandError + +@client.command( + name="unskip", + description="Mark yourself as available for pair programming on some date in the future." +) +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _unskip( + interaction: discord.Interaction, + human_date: Optional[str], ): - if isinstance(error, app_commands.CommandOnCooldown): - return await interaction.response.send_message( - f"Command is currently on cooldown, try again in {error.retry_after:.2f} seconds.", - ephemeral=True, + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + session = client.make_orm_session() + with session.begin(): + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() ) - elif isinstance(error, app_commands.MissingPermissions): - return await interaction.response.send_message( - "You don't have the permissions to do that.", ephemeral=True - ) - else: - raise error - - -@tasks.loop(hours=1) -async def pairing_cron(): - def should_run(): - now = datetime.utcnow() - hour = now.time().hour - logger.debug(f"Checking pairing job at UTC:{now}.") - return hour == 8 - - if should_run(): - await run_pairing() - - -async def run_pairing(): - now = datetime.utcnow() - print(now) - logger.debug(f"Running pairing job at UTC:{now}.") - weekday = now.weekday() - weekday_map = { - 0: Timeblock.Monday, - 1: Timeblock.Tuesday, - 2: Timeblock.Wednesday, - 3: Timeblock.Thursday, - 4: Timeblock.Friday, - 5: Timeblock.Saturday, - 6: Timeblock.Sunday, - } - timeblock = weekday_map[weekday] - for guild in client.guilds: - await pair(guild.id, timeblock) - # weekly Monday match - if weekday == 0: - await pair(guild.id, Timeblock.WEEK) - - -async def pair(guild_id: int, timeblock: Timeblock): - try: - userids = db.query_timeblock(guild_id, timeblock) - users = [client.get_user(userid) for userid in userids] - # Users may leave the server without unsubscribing - # TODO: listen to that event and drop them from the table - users = list(filter(None, users)) - logger.info( - f"Pairing for G:{guild_id} T:{timeblock.name} with {len(users)}/{len(userids)} users." + if channel is None: + await interaction.response.send_message( + f"Pairbot is not active in this channel.", + ephemeral=True + ) + return + + schedule = ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() ) - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel = client.get_channel(guild_to_channel[str(guild_id)]) - if len(users) < 2: - for user in users: - logger.info( - f"G:{guild_id} T:{timeblock.name} pair failed, dming U:{user.id}." - ) - msg = ( - f"Thanks for signing up for pairing this {timeblock}. " - "Unfortunately, there was nobody else available this time." - ) - await dm_user(user, msg) - await channel.send( - f"Not enough signups this {timeblock}. Try `/subscribe` to sign up!" + + if schedule is None or len(schedule.days_available) == 0: + await interaction.response.send_message( + f"You are not subscribed to pair programming in #{interaction.channel.name}.", + ephemeral=True ) return - random.shuffle(users) - groups = [users[i :: len(users) // 2] for i in range(len(users) // 2)] - for group in groups: - notify_msg = ", ".join(f"<@{user.id}>" for user in group) - notify_msg = f"{notify_msg}: you've been matched together for this {timeblock}. Happy pairing! :computer:" - await create_group_thread(guild_id, group, channel, notify_msg) - logger.info( - f"G:{guild_id} C:{channel.id} paired U:{[user.id for user in group]}." + if human_date is None: + adjustment_date = None + current_weekday = datetime.now().weekday() + for i in range(0, 7): + weekday = Weekday((current_weekday + i) % 7) + if schedule[weekday]: + adjustment_date = datetime.now().date() + timedelta(days=i) + break + + assert adjustment_date is not None + else: + # Clean things up for the parser + cleaned_date = human_date.lower() + cleaned_date = cleaned_date.replace("next", "") + + # Try to parse the date + adjustment_datetime = dateparser.parse( + cleaned_date, + settings={ + "PREFER_DATES_FROM": "future", + "RELATIVE_BASE": datetime.now(), + }, + languages=["en"], + ) + if adjustment_datetime is None: + await interaction.response.send_message( + f"Could not parse date \"{human_date}\".", + ephemeral=True + ) + return + + adjustment_date = adjustment_datetime.date() + if adjustment_datetime < datetime.now(): + await interaction.response.send_message( + f"Cannot unskip a date in the past: {adjustment_date.strftime('%A %B %d, %Y')}.", + ephemeral=True + ) + return + + weekday = Weekday(adjustment_date.weekday()) + if schedule[weekday] == False: + await interaction.response.send_message( + f"You are not subscribed to pair programming in #{interaction.channel.name} on {weekday}.", + ephemeral=True + ) + return + + adjustment = ( + session.query(ScheduleAdjustment) + .filter(ScheduleAdjustment.channel_id == interaction.channel_id) + .filter(ScheduleAdjustment.user_id == interaction.user.id) + .filter(func.DATE(ScheduleAdjustment.date) == adjustment_date) + .one_or_none() + ) + + if adjustment is None: + adjustment = ScheduleAdjustment( + channel_id=interaction.channel_id, + user_id=interaction.user.id, + date=adjustment_date, + available=True, ) - await channel.send( - f"Pairings for {len(users)} users have been sent out for this {timeblock}. Try `/subscribe` to sign up!" + session.add(adjustment) + else: + if adjustment.available == True: + await interaction.response.send_message( + f"You already unskipped pairing on {adjustment_date.strftime('%A %B %d, %Y')}.", + ephemeral=True + ) + return + adjustment.available = True + session.commit() + + logger.info( + f"Unskipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')} for user \"{interaction.user.name}\" in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." + ) + + msg = f"Successfully unskipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')}." + await interaction.response.send_message(msg, ephemeral=True) + + +@client.command( + name="viewschedule", + description="View your pair programming schedule." +) +@discord.app_commands.guild_only() +async def _view_schedule( + interaction: discord.Interaction, +): + assert interaction.guild is not None + + session = client.make_orm_session() + with session.begin(): + schedules = ( + session.query(Schedule) + .filter(Schedule.user_id == interaction.user.id) + .all() ) - except Exception as e: - logger.error(e, exc_info=True) + adjustments = ( + session.query(ScheduleAdjustment) + .filter(ScheduleAdjustment.user_id == interaction.user.id) + .all() + ) -def local_setup(): - try: - read_guild_to_channel(GUILDS_PATH) - except Exception: - with open(GUILDS_PATH, "w") as f: - json.dump({}, f) + guild_schedule = { + interaction.guild.get_channel(schedule.channel_id): [str(day) for day in schedule.days_available] + for schedule in schedules + } + + skipped = dict() + unskipped = dict() + + for adjustment in adjustments: + channel_name = interaction.guild.get_channel(adjustment.channel_id) + if adjustment.available == False: + if channel_name not in skipped: + skipped[channel_name] = [] + skipped[channel_name].append(adjustment.date.strftime("%a %b %d")) + if adjustment.available == True: + if channel_name not in unskipped: + unskipped[channel_name] = [] + unskipped[channel_name].append(adjustment.date.strftime("%a %b %d")) + + if len(guild_schedule) == 0 or sum([len(days) for days in guild_schedule.values()]) == 0: + msg = "You are not subscribed to pair programming." + elif len(guild_schedule) == 1: + channel_name, schedule = guild_schedule.popitem() + msg = f"You are subscribed to pair programming in #{channel_name} on {', '.join(schedule)}" + if channel_name in skipped and len(skipped[channel_name]) > 0: + msg += f" (skipping {', '.join(skipped[channel_name])})" + else: + msg = "You are subscribed to pair programming in the following channels:\n" + for channel_name, schedule in guild_schedule.items(): + msg += f"* #{channel_name}: {', '.join(schedule)}" + if channel_name in skipped and len(skipped[channel_name]) > 0: + msg += f" (skipping {', '.join(skipped[channel_name])})" + msg += "\n" + + await interaction.response.send_message(msg, ephemeral=True) + + +@client.command( + name="pairwith", + description="Start a pairing session with another channel member." +) +@discord.app_commands.guild_only() +async def _pair_with( + interaction: discord.Interaction, + user: discord.Member, +): + await interaction.response.send_message("Works", ephemeral=True) -@client.event -async def on_ready(): - local_setup() - await client.wait_until_ready() - tree.on_error = on_tree_error - for guild in client.guilds: - tree.copy_global_to(guild=guild) - await tree.sync(guild=guild) - print("Code sync complete!") - pairing_cron.start() - print("Starting cron loop...") - logger.info("Bot started.") +@discord.ext.tasks.loop(time=config.PAIRING_TIME) +async def make_groups(): + pass def run(): - client.run(BOT_TOKEN) + client.run(config.DISCORD_BOT_TOKEN) diff --git a/pairbot/config.py b/pairbot/config.py new file mode 100644 index 0000000..5001e94 --- /dev/null +++ b/pairbot/config.py @@ -0,0 +1,67 @@ +"""config.py + +This module contains pairbot's configuration settings, stored as module-level variables +with uppercase names. +""" + +import os +from pathlib import Path +from datetime import time +import logging.config + +import dotenv +import discord + +BASE_DIR = Path(__file__).resolve().parent.parent + +class ConfigError(Exception): + """Exception raised for errors in application configuration.""" + pass + +ENV = os.environ.get("ENV", "development") +if ENV not in ["development", "testing", "production"]: + raise ConfigError(f"Invalid environment: {ENV}") + +DOTENV_PATH = os.path.join(BASE_DIR, f".env.{ENV}") +if not os.path.isfile(DOTENV_PATH): + raise ConfigError(f"Could not find dotenv file {DOTENV_PATH}.") +dotenv.load_dotenv(DOTENV_PATH) + +# Load required environment variables +try: + DISCORD_BOT_TOKEN = os.environ["DISCORD_BOT_TOKEN"] + DATABASE_URL = os.environ["DATABASE_URL"] + PAIRING_TIME = time.fromisoformat(os.environ["PAIRING_TIME"]) +except KeyError as e: + raise ConfigError(f"Missing environment variable: {e.args[0]}.") + +# Configure logger +LOGGING = { + "version": 1, + "formatters": { + "default": { + "format": "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + "datefmt": "%H:%M:%S", + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + } + }, + "root": { + "handlers": ["console"], + "level": "DEBUG", + } +} + +logging.config.dictConfig(LOGGING) + +logger = logging.getLogger(__name__) + +# Log config information +logger.info(f"Running in {ENV} mode.") + +# Disable PyNaCl warning +discord.VoiceClient.warn_nacl = False diff --git a/pairbot/models.py b/pairbot/models.py new file mode 100644 index 0000000..b332273 --- /dev/null +++ b/pairbot/models.py @@ -0,0 +1,129 @@ +"""models.py + +This module contains Peewee ORM models and their respective logic. +""" +import calendar +import enum +from datetime import datetime + +from sqlalchemy import ( + PrimaryKeyConstraint, + Boolean, + Column, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, +) + + +class Base(DeclarativeBase): + pass + + +class PairingChannel(Base): + """A Discord channel in which Pairbot has been allowed to operate.""" + __tablename__ = "channel" + + guild_id: Mapped[int] + """The Discord guild ID.""" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + active: Mapped[bool] + """Whether Pairbot is active in the channel.""" + + leetcode_integration: Mapped[bool] + """Whether leetcode integration is active in the channel or not.""" + + __table_args__ = ( + PrimaryKeyConstraint("guild_id", "channel_id",), + ) + + +class Weekday(enum.IntEnum): + """An enum representing the day of the week (0-6 starting from Monday).""" + Monday = 0 + Tuesday = 1 + Wednesday = 2 + Thursday = 3 + Friday = 4 + Saturday = 5 + Sunday = 6 + + def __str__(self): + return calendar.day_name[self] + + +class Schedule(Base): + """Represents the availability of a user in a channel on a given day of the week""" + __tablename__ = "schedule" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + user_id: Mapped[int] + """The Discord user ID.""" + + available_Monday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Tuesday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Wednesday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Thursday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Friday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Saturday: Mapped[bool] = mapped_column(Boolean, default=False) + available_Sunday: Mapped[bool] = mapped_column(Boolean, default=False) + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "user_id"), + ) + + # hack to get Schedule[day_of_week] to work + def __getitem__(self, day_of_week: Weekday): + return getattr(self, f"available_{day_of_week}") + + def __setitem__(self, day_of_week: Weekday, value: bool): + setattr(self, f"available_{day_of_week}", value) + + @property + def days_available(self): + return [day for day in Weekday if self[day] == True] + +class ScheduleAdjustment(Base): + """Represents an adjustment to a user's availability on a specific date.""" + __tablename__ = "schedule_adjustment" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + user_id: Mapped[int] + """The Discord user ID.""" + + date: Mapped[datetime] + """The date on which the user's availability is set.""" + + available: Mapped[bool] + """Whether the user is available on this date.""" + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "user_id", "date"), + ) + + +class Thread(Base): + """Represents a user's membership in a Discord thread created by Pairbot.""" + __tablename__ = "pairing" + + channel_id: Mapped[int] + """The Discord channel ID.""" + + thread_id: Mapped[int] + """The Discord thread ID""" + + user_id: Mapped[int] + """The Discord user ID.""" + + __table_args__ = ( + PrimaryKeyConstraint("channel_id", "thread_id", "user_id"), + ) diff --git a/pairbot/utils.py b/pairbot/utils.py index fb67c83..ab3ee51 100644 --- a/pairbot/utils.py +++ b/pairbot/utils.py @@ -2,6 +2,7 @@ import json import discord +import discord.ext.commands def parse_args(): diff --git a/pyproject.toml b/pyproject.toml index ef28a21..cd7e039 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,11 @@ name = "pairbot" version = "0.0.1" readme = "README.md" dependencies = [ + "SQLAlchemy", + "alembic", "discord", - "python-dotenv" + "python-dotenv", + "dateparser" ] [project.urls] @@ -21,6 +24,7 @@ pairbot = "pairbot:run" # Testing [tool.hatch.envs.default] +python = "3.11" dependencies = [ "coverage[toml]>=6.5", "pytest", From 2a0ba0997c481910f3023a8036bd627f3d8604ca Mon Sep 17 00:00:00 2001 From: Yaxel Date: Thu, 30 Nov 2023 22:57:35 -0800 Subject: [PATCH 2/3] Rewrite part 2 --- .gitignore | 3 + .../versions/c2ab7b1ce2f4_default_values.py | 30 - ..._migrations_v.py => fab85b27c1d3_nuked.py} | 34 +- pairbot/__init__.py | 2 +- pairbot/app.py | 99 +++ pairbot/client.py | 74 +- pairbot/context.py | 57 ++ pairbot/discord.py | 90 +++ pairbot/globals.py | 6 + pairbot/main.py | 750 ++++++++++++++++++ pairbot/models.py | 73 +- 11 files changed, 1138 insertions(+), 80 deletions(-) delete mode 100644 alembic/versions/c2ab7b1ce2f4_default_values.py rename alembic/versions/{d588346e0112_nuked_migrations_v.py => fab85b27c1d3_nuked.py} (62%) create mode 100644 pairbot/app.py create mode 100644 pairbot/context.py create mode 100644 pairbot/discord.py create mode 100644 pairbot/globals.py create mode 100644 pairbot/main.py diff --git a/.gitignore b/.gitignore index ff3cfc4..b57285a 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,7 @@ cython_debug/ *.log .env +.env.* data/ + +*.sqlite3 diff --git a/alembic/versions/c2ab7b1ce2f4_default_values.py b/alembic/versions/c2ab7b1ce2f4_default_values.py deleted file mode 100644 index 1b47059..0000000 --- a/alembic/versions/c2ab7b1ce2f4_default_values.py +++ /dev/null @@ -1,30 +0,0 @@ -"""default values - -Revision ID: c2ab7b1ce2f4 -Revises: d588346e0112 -Create Date: 2023-11-25 00:16:18.559944 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'c2ab7b1ce2f4' -down_revision: Union[str, None] = 'd588346e0112' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### diff --git a/alembic/versions/d588346e0112_nuked_migrations_v.py b/alembic/versions/fab85b27c1d3_nuked.py similarity index 62% rename from alembic/versions/d588346e0112_nuked_migrations_v.py rename to alembic/versions/fab85b27c1d3_nuked.py index 627ff9c..5678b18 100644 --- a/alembic/versions/d588346e0112_nuked_migrations_v.py +++ b/alembic/versions/fab85b27c1d3_nuked.py @@ -1,8 +1,8 @@ -"""Nuked migrations :v +"""nuked -Revision ID: d588346e0112 +Revision ID: fab85b27c1d3 Revises: -Create Date: 2023-11-25 00:12:31.330617 +Create Date: 2023-11-25 23:32:30.908140 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = 'd588346e0112' +revision: str = 'fab85b27c1d3' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -23,33 +23,33 @@ def upgrade() -> None: op.create_table('channel', sa.Column('guild_id', sa.Integer(), nullable=False), sa.Column('channel_id', sa.Integer(), nullable=False), - sa.Column('active', sa.Boolean(), nullable=False), sa.Column('leetcode_integration', sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint('guild_id', 'channel_id') ) op.create_table('pairing', sa.Column('channel_id', sa.Integer(), nullable=False), sa.Column('thread_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('channel_id', 'thread_id', 'user_id') + sa.Column('user_1_id', sa.Integer(), nullable=False), + sa.Column('user_2_id', sa.Integer(), nullable=False), + sa.Column('user_3_id', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('channel_id', 'thread_id') ) op.create_table('schedule', sa.Column('channel_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('available_Monday', sa.Boolean(), nullable=False), - sa.Column('available_Tuesday', sa.Boolean(), nullable=False), - sa.Column('available_Wednesday', sa.Boolean(), nullable=False), - sa.Column('available_Thursday', sa.Boolean(), nullable=False), - sa.Column('available_Friday', sa.Boolean(), nullable=False), - sa.Column('available_Saturday', sa.Boolean(), nullable=False), - sa.Column('available_Sunday', sa.Boolean(), nullable=False), + sa.Column('available_0', sa.Boolean(), nullable=False), + sa.Column('available_1', sa.Boolean(), nullable=False), + sa.Column('available_2', sa.Boolean(), nullable=False), + sa.Column('available_3', sa.Boolean(), nullable=False), + sa.Column('available_4', sa.Boolean(), nullable=False), + sa.Column('available_5', sa.Boolean(), nullable=False), + sa.Column('available_6', sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint('channel_id', 'user_id') ) - op.create_table('schedule_adjustment', + op.create_table('skip', sa.Column('channel_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('date', sa.DateTime(), nullable=False), - sa.Column('available', sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint('channel_id', 'user_id', 'date') ) # ### end Alembic commands ### @@ -57,7 +57,7 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('schedule_adjustment') + op.drop_table('skip') op.drop_table('schedule') op.drop_table('pairing') op.drop_table('channel') diff --git a/pairbot/__init__.py b/pairbot/__init__.py index eb123b4..86966d1 100644 --- a/pairbot/__init__.py +++ b/pairbot/__init__.py @@ -1 +1 @@ -from .client import run +from .main import run diff --git a/pairbot/app.py b/pairbot/app.py new file mode 100644 index 0000000..4348ff0 --- /dev/null +++ b/pairbot/app.py @@ -0,0 +1,99 @@ +import pdb +import sys + +from typing import ( + Any, + Callable, + Concatenate, + Coroutine, + ParamSpec, + TypeVar, +) + +import logging +from functools import wraps +from contextvars import ContextVar + +import discord + +logger = logging.getLogger(__name__) + +# Globally accessible context variables +interaction_ctx: ContextVar[discord.Interaction[Any]] = ContextVar("interaction") + + +class InteractionContext: + def __init__(self, interaction: discord.Interaction[Any]): + self.interaction = interaction + self.token = None + + def __enter__(self) -> None: + self.token = interaction_ctx.set(self.interaction) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self.token is not None: + interaction_ctx.reset(self.token) + + +class InteractionLoggingAdapter(logging.LoggerAdapter): + """A custom formatter for logging inside an interaction context""" + def process(self, msg, kwargs): + interaction = interaction_ctx.get(None) + new_msg = "" + if interaction is not None: + new_msg += "Interaction" + if interaction.guild is not None: + new_msg += f" in guild \"{interaction.guild.name}\" ({interaction.guild.id})" + if interaction.channel is not None and isinstance(interaction.channel, discord.TextChannel): + new_msg += f" channel \"#{interaction.channel.name} ({interaction.channel.id})\"" + new_msg += f" by user \"@{interaction.user.name}\" ({interaction.user.id}): " + new_msg += str(msg) + return new_msg, kwargs + + +# Type fuckery for the command decorator +T = TypeVar("T") +P = ParamSpec("P") +CommandCallback = Callable[Concatenate[discord.Interaction[Any], P], Coroutine[Any, Any, T]] + +class DiscordApp: + """A wrapper around a discord client instance.""" + def __init__(self, intents: discord.Intents, **options: Any): + self.client = discord.Client(intents=intents, **options) + self.command_tree = discord.app_commands.CommandTree(self.client) + self.logger = InteractionLoggingAdapter(logger) + + # register on_ready event + @self.client.event + async def on_ready(): + for guild in self.client.guilds: + self.logger.info("Copying command tree to guild \"%s\" (%d)", guild.name, guild.id) + self.command_tree.copy_global_to(guild=guild) + await self.command_tree.sync(guild=guild) + self.logger.info("Application ready.") + + def command(self, **options: Any): + """Wrapper around discord slash commands.""" + def decorator(callback: CommandCallback): + @self.command_tree.command(**options) + @wraps(callback) + async def wrapper(i: discord.Interaction[Any], *args: P.args, **kwargs: P.kwargs) -> None: + assert i.command is not None + pretty_kwargs = ( + " with arguments { " + + ", ".join((f"{key}=\"{str(value)}\"" for key, value in kwargs.items())) + + " }" + if len(kwargs) > 0 else " with no arguments" + ) + self.logger.info("Executing slash command /%s%s", i.command.name, pretty_kwargs) + try: + with InteractionContext(i): + await callback(i, *args, **kwargs) + except Exception as e: + self.logger.error(e, exc_info=True) + await i.response.send_message("Pairbot broke somehow :v", ephemeral=True) + return wrapper + return decorator + + def run(self, token: str, *args, **kwargs): + self.client.run(token, *args, **kwargs) diff --git a/pairbot/client.py b/pairbot/client.py index 49c1fdc..5f93540 100644 --- a/pairbot/client.py +++ b/pairbot/client.py @@ -19,7 +19,8 @@ import sqlalchemy from sqlalchemy import ( create_engine, - func + func, + or_, ) import sqlalchemy.orm from sqlalchemy.orm import ( @@ -44,7 +45,7 @@ PairingChannel, Schedule, ScheduleAdjustment, - Thread, + Pairing, ) logger = logging.getLogger(__name__) @@ -668,7 +669,74 @@ async def _pair_with( interaction: discord.Interaction, user: discord.Member, ): - await interaction.response.send_message("Works", ephemeral=True) + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + if user.id == interaction.user.id: + await interaction.response.send_message( + f"You cannot pair with yourself.", + ephemeral=True + ) + return + + session = client.make_orm_session() + with session.begin(): + pairing = ( + session.query(Pairing) + .filter(or_( + Pairing.user_1_id == interaction.user.id, + Pairing.user_2_id == interaction.user.id, + )) + .filter(or_( + Pairing.user_1_id == user.id, + Pairing.user_2_id == user.id, + )) + .one_or_none() + ) + + usernames = sorted([user.name for user in (interaction.user, user)]) + + if pairing is None: + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + channel_id=interaction.channel.id, + thread_id=thread.id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + session.commit() + else: + thread = interaction.guild.get_thread(pairing.thread_id) + if thread is None: + session.delete(pairing) + session.commit() + + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + channel_id=interaction.channel.id, + thread_id=thread.id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + session.commit() + + await thread.send( + f"<@{interaction.user.id}> has started a pairing session with you, <@{user.id}>. Happy pairing! :computer:" + ) + await interaction.response.send_message( + f"Successfully created pairing thread with <@{user.id}>", + ephemeral=True + ) @discord.ext.tasks.loop(time=config.PAIRING_TIME) diff --git a/pairbot/context.py b/pairbot/context.py new file mode 100644 index 0000000..f5797c0 --- /dev/null +++ b/pairbot/context.py @@ -0,0 +1,57 @@ + +import contextvars +from types import TracebackType +from typing import ( + Any, + Self, +) + +from discord import Interaction + +from . import globals +from .app import App + +class AppContext: + """The app context contains application-specific information (similar to Flask's application + context).""" + + def __init__( + self, + app: App, + ): + self.app = app + self._cv_tokens: list[contextvars.Token] = [] + + def push(self) -> None: + """Binds the app context to the current context.""" + self._cv_tokens.append(globals.app.set(self)) + + def pop(self, exc: BaseException | None): + ctx = globals.app.get() + assert ctx is self + globals.app.reset(self._cv_tokens.pop()) + + def __enter__(self) -> Self: + self.push() + return self + + def __exit__( + self, + exc_type: type | None, + exc_value: BaseException | None, + tb: TracebackType | None + ) -> None: + self.pop(exc_value) + + +class InteractionContext: + """The interaction context contains interaction-specific information (similar to Flask's request + context). It is created and pushed at the beginning of an interaction, and popped at the end.""" + + def __init__( + self, + app: App, + interaction: Interaction[Any], + ): + self.app = app + self.interaction = interaction diff --git a/pairbot/discord.py b/pairbot/discord.py new file mode 100644 index 0000000..0338556 --- /dev/null +++ b/pairbot/discord.py @@ -0,0 +1,90 @@ +"""discord.py + +Utilities for interfacing with the Discord API.""" + +from abc import ABC +from typing import ( + Any, + Callable, + Concatenate, + Coroutine, + ParamSpec, + TypeVar, + TypedDict, + Union, +) + +import functools +import logging + +import discord + +logger = logging.getLogger(__name__) + + +class AppContext(TypedDict): + logger_adapter: Union[logging.Logger, logging.LoggerAdapter] + + +class AppCommandLoggerAdapter(logging.LoggerAdapter): + """Custom LoggerAdapter for adding context to logs.""" + def process(self, msg: str, kwargs: dict): + """Add guild, channel, and user IDs to slash command logs.""" + if self.extra is None: + return msg, kwargs + + guild_id = self.extra.get("guild_id") + channel_id = self.extra.get("channel_id") + user_id = self.extra.get("user_id") + + s = "" + if guild_id is not None: s += f"g: <{guild_id} " + if channel_id is not None: s += f"c: <{channel_id} " + if user_id is not None: s += f"u: <{user_id} " + s += msg + + return s, kwargs + +# Type fuckery for the command decorator +T = TypeVar("T") +P = ParamSpec("P") +CommandCallback = Callable[Concatenate[AppContext, discord.Interaction[Any], P], Coroutine[Any, Any, T]] + + +class Application(ABC): + """A wrapper around a discord client instance.""" + + def __init__( + self, + intents: discord.Intents = discord.Intents.all(), + **options: Any + ): + self.client = discord.Client(intents=intents, **options) + self.command_tree = discord.app_commands.CommandTree(self.client) + + def get_context(self, interaction: discord.Interaction[Any]) -> AppContext: + return { + "logger_adapter": AppCommandLoggerAdapter( + logger, + { + "guild_id": interaction.guild_id, + "channel_id": interaction.channel_id, + "user_id": interaction.user.id, + } + ), + } + + def command(self, **options: Any): + """Wrapper for adding logging and error-handling to application slash commands.""" + def decorator(callback: CommandCallback): + @self.command_tree.command(**options) + @functools.wraps(callback) + async def wrapper(interaction: discord.Interaction[Any], *args: P.args, **kwargs: P.kwargs) -> None: + await callback( + self.get_context(interaction), + interaction, + *args, + **kwargs + ) + return wrapper + return decorator diff --git a/pairbot/globals.py b/pairbot/globals.py new file mode 100644 index 0000000..ff9bb36 --- /dev/null +++ b/pairbot/globals.py @@ -0,0 +1,6 @@ +from contextvars import ContextVar + +from .context import AppContext, InteractionContext + +app: ContextVar[AppContext] = ContextVar("app") +interaction: ContextVar[InteractionContext] = ContextVar("interaction") diff --git a/pairbot/main.py b/pairbot/main.py new file mode 100644 index 0000000..1452a4e --- /dev/null +++ b/pairbot/main.py @@ -0,0 +1,750 @@ +import pdb + +from typing import ( + Any, + Optional, +) + +import random +from datetime import ( + date, + datetime, + timedelta, +) +import dateparser + +import discord +import discord.ext.tasks + +import sqlalchemy +import sqlalchemy.orm + +from . import config + +from .app import DiscordApp +from .models import ( + PairingChannel, + Schedule, + Skip, + Weekday, + Pairing, +) + +app = DiscordApp(intents=discord.Intents.all()) + +db_engine = sqlalchemy.create_engine(config.DATABASE_URL) +Session = sqlalchemy.orm.sessionmaker(db_engine) + + +# Utility functions + + +def get_active_pairing_channel( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Optional[PairingChannel]: + if interaction.channel is None: return None + return ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel.id) + .one_or_none() + ) + + +def get_user_schedule( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Optional[Schedule]: + if interaction.channel is None: return None + return ( + session.query(Schedule) + .filter(Schedule.channel_id == interaction.channel_id) + .filter(Schedule.user_id == interaction.user.id) + .one_or_none() + ) + + +def make_user_schedule( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> Schedule: + assert interaction.channel is not None + schedule = Schedule( + channel_id = interaction.channel_id, + user_id = interaction.user.id, + ) + session.add(schedule) + app.logger.info("Created new schedule.") + return schedule + + +def parse_human_readable_date(maybe_a_date: str) -> Optional[date]: + # Preprocess with some ad-hoc rules + maybe_a_date = ( + maybe_a_date + .lower() + .replace("next", "") + ) + parsed_date = dateparser.parse( + maybe_a_date, + settings={ + "PREFER_DATES_FROM": "future", + "RELATIVE_BASE": datetime.now(), + }, + languages=["en"], + ) + if parsed_date is not None: + return parsed_date.date() + + +def get_next_scheduled_date(schedule: Schedule) -> Optional[date]: + current_weekday = datetime.now().weekday() + for i in range(0, 7): + next_weekday = Weekday((current_weekday + i) % 7) + if schedule.is_available_on(next_weekday): + return datetime.now().date() + timedelta(days=i) + return None + + +def format_date(d: date): + return d.strftime('%A %B %d, %Y') + + +def get_skip( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, + d: date, +) -> Optional[Skip]: + return ( + session.query(Skip) + .filter(Skip.channel_id == interaction.channel_id) + .filter(Skip.user_id == interaction.user.id) + .filter(sqlalchemy.func.DATE(Skip.date) == d) + .one_or_none() + ) + + +def make_skip( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, + d: date, +) -> Skip: + skip = Skip( + channel_id=interaction.channel_id, + user_id=interaction.user.id, + date=d, + ) + session.add(skip) + app.logger.info("Created new skip.") + return skip + + +async def fail_on_inactive_channel( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, +) -> bool: + channel = get_active_pairing_channel(session, interaction) + if channel is None: + app.logger.info("Pairbot not active in channel.") + await interaction.response.send_message( + "Pairbot is not active in this channel.", + ephemeral=True + ) + return True + else: + return False + + +async def fail_on_existing_subscription( + interaction: discord.Interaction, + schedule: Schedule, + weekday: Optional[Weekday], +) -> bool: + assert isinstance(interaction.channel, discord.TextChannel) + if weekday is not None and schedule.is_available_on(weekday): + app.logger.info("Already subscribed on %s", weekday) + await interaction.response.send_message( + "You are already subscribed to pair programming on %s in #%s." % ( + weekday, + interaction.channel.name), + ephemeral=True + ) + return True + elif schedule.num_days_available() == 7: + app.logger.info("Already subscribed daily") + await interaction.response.send_message( + "You are already subscribed to daily pair programming in #%s." % interaction.channel.name, + ephemeral=True + ) + return True + return False + + +async def fail_on_nonexistent_subscription( + interaction: discord.Interaction, + schedule: Optional[Schedule], + weekday: Optional[Weekday], +) -> bool: + assert isinstance(interaction.channel, discord.TextChannel) + if schedule is not None and weekday is not None and not schedule.is_available_on(weekday): + app.logger.info("Not subscribed on %s", weekday) + await interaction.response.send_message( + "You are not subscribed to pair programming on %s in #%s." % ( + weekday, + interaction.channel.name), + ephemeral=True + ) + return True + elif schedule is None or schedule.num_days_available() == 0: + await interaction.response.send_message( + ("You are not subscribed to pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + return True + return False + + +def get_pairing( + session: sqlalchemy.orm.Session, + interaction: discord.Interaction, + user_1: discord.User | discord.Member, + user_2: discord.Member +) -> Optional[Pairing]: + return ( + session.query(Pairing) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_1.id, + Pairing.user_2_id == user_1.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_2.id, + Pairing.user_2_id == user_2.id, + )) + .one_or_none() + ) + + +# Slash commands + + +@app.command( + name="addpairbot", + description="Add Pairbot to the current channel.", +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _add_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() + ) + + if channel is not None: + app.logger.info("Pairbot is already added.") + await interaction.response.send_message( + "Pairbot is already added to \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + return + + channel = PairingChannel( + guild_id = interaction.guild.id, + channel_id = interaction.channel.id, + leetcode_integration = False, # TODO + ) + session.add(channel) + + app.logger.info("Added Pairbot.") + + await interaction.response.send_message( + "Added Pairbot to \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + + +@app.command( + name="removepairbot", + description="Remove Pairbot from the current channel." +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _remove_pairbot(interaction: discord.Interaction): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + channel = ( + session.query(PairingChannel) + .filter(PairingChannel.channel_id == interaction.channel_id) + .one_or_none() + ) + + if channel is None: + app.logger.info("Pairbot is not added.") + await interaction.response.send_message( + "Pairbot is not added to \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + return + + session.delete(channel) + + app.logger.info("Removed pairbot.") + + await interaction.response.send_message( + "Removed Pairbot from \"#%s\"." % interaction.channel.name, + ephemeral=True, + ) + + +@app.command( + name="subscribe", + description="Subscribe to pair programming (every day if no weekday specified)." +) +@discord.app_commands.guild_only() +async def _subscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if schedule is None: + schedule = make_user_schedule(session, interaction) + else: + app.logger.info("Found existing schedule") + + if fail_on_existing_subscription(interaction, schedule, weekday): + return + + if weekday is not None: + schedule.set_availability_on(weekday, True) + app.logger.info("Subscribed to pair programming on %s", weekday) + await interaction.response.send_message( + "Successfully subscribed to pair programming on %s in #%s." % (weekday, interaction.channel.name), + ephemeral=True + ) + else: + schedule.set_availability_every_day(True) + app.logger.info("Subscribed to daily pair programming") + await interaction.response.send_message( + "Successfully subscribed to daily pair programming in #%s." % interaction.channel.name, + ephemeral=True + ) + + +@app.command( + name="unsubscribe", + description="Unsubscribe from pair programming (every day if no weekday specified)." +) +@discord.app_commands.guild_only() +async def _unsubscribe( + interaction: discord.Interaction, + weekday: Optional[Weekday], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if schedule is None or schedule.num_days_available() == 0: + await interaction.response.send_message( + ("You are already not subscribed to pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + return + + if fail_on_nonexistent_subscription(interaction, schedule, weekday): + return + + if weekday is not None: + schedule.set_availability_on(weekday, False) + await interaction.response.send_message( + ("Successfully unsubscribed from pair programming on %s in #%s." % weekday, interaction.channel.name), + ephemeral=True + ) + else: + schedule.set_availability_every_day(False) + await interaction.response.send_message( + ("Successfully unsubscribed from all pair programming in #%s." % interaction.channel.name), + ephemeral=True + ) + + +@app.command( + name="skip", + description="Mark yourself as unavailable for pair programming on some date in the future)." +) +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _skip( + interaction: discord.Interaction, + human_date: Optional[str], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if fail_on_nonexistent_subscription(interaction, schedule, None): + return + + assert schedule is not None + + if human_date is None: + skipped_date = get_next_scheduled_date(schedule) + assert skipped_date is not None + else: + skipped_date = parse_human_readable_date(human_date) + if skipped_date is None: + await interaction.response.send_message( + "Could not parse date \"%s\"." % date, + ephemeral=True + ) + return + + if skipped_date < datetime.now().date(): + await interaction.response.send_message( + "Cannot skip a date in the past: %s." % format_date(skipped_date), + ephemeral=True + ) + return + + skipped_weekday = Weekday(skipped_date.weekday()) + + if fail_on_nonexistent_subscription(interaction, schedule, skipped_weekday): + return + + skip = get_skip(session, interaction, skipped_date) + + if skip is None: + skip = make_skip(session, interaction, skipped_date) + else: + await interaction.response.send_message( + "You already skipped pairing on %s." % format_date(skipped_date), + ephemeral=True + ) + return + + app.logger.info("Skipped pair programming on %s" % format_date(skipped_date)) + await interaction.response.send_message( + "Successfully skipped pair programming on %s" % format_date(skipped_date), + ephemeral=True + ) + + +@app.command( + name="unskip", + description="Unskip a skipped pairing session." +) +@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") +@discord.app_commands.guild_only() +async def _unskip( + interaction: discord.Interaction, + human_date: Optional[str], +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + schedule = get_user_schedule(session, interaction) + + if fail_on_nonexistent_subscription(interaction, schedule, None): + return + assert schedule is not None + + if human_date is None: + unskipped_date = get_next_scheduled_date(schedule) + assert unskipped_date is not None + else: + unskipped_date = parse_human_readable_date(human_date) + if unskipped_date is None: + await interaction.response.send_message( + "Could not parse date \"%s\"." % date, + ephemeral=True + ) + return + + if unskipped_date < datetime.now().date(): + await interaction.response.send_message( + "Cannot unskip a date in the past: %s." % format_date(unskipped_date), + ephemeral=True + ) + return + + unskipped_weekday = Weekday(unskipped_date.weekday()) + + if fail_on_nonexistent_subscription(interaction, schedule, unskipped_weekday): + return + + skip = get_skip(session, interaction, unskipped_date) + + if skip is None: + await interaction.response.send_message( + "You did not skip pair programming on %s." % format_date(unskipped_date), + ephemeral=True + ) + return + else: + session.delete(skip) + + app.logger.info("Unskipped pair programming on %s" % format_date(unskipped_date)) + await interaction.response.send_message( + "Successfully unskipped pair programming on %s" % format_date(unskipped_date), + ephemeral=True + ) + + +@app.command( + name="viewschedule", + description="View your pair programming schedule." +) +@discord.app_commands.guild_only() +async def _view_schedule( + interaction: discord.Interaction, +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + all_schedules = ( + session.query(Schedule) + .filter(Schedule.user_id == interaction.user.id) + .all() + ) + + all_skips = ( + session.query(Skip) + .filter(Skip.user_id == interaction.user.id) + .filter(sqlalchemy.func.DATE(Skip.date) >= datetime.now()) + .all() + ) + + grouped_schedules: dict[str, list[str]] = dict() + grouped_skips: dict[str, list[str]] = dict() + + for schedule in all_schedules: + channel = interaction.guild.get_channel(schedule.channel_id) + if channel is None: continue + channel_name = channel.name + grouped_schedules[channel_name] = list(map(str, schedule.days_available())) + + for skip in all_skips: + channel = interaction.guild.get_channel(skip.channel_id) + if channel is None: continue + channel_name = channel.name + if channel_name not in grouped_skips: + grouped_skips[channel_name] = [] + grouped_skips[channel_name].append(skip.date.strftime("%a %b %d")) + + if len(all_schedules) == 0 or sum([len(days) for days in grouped_schedules.values()]) == 0: + await interaction.response.send_message( + "You are not subscribed to pair programming in any channel.", + ephemeral=True + ) + return + + if len(grouped_schedules) == 1: + channel_name, weekdays = grouped_schedules.popitem() + if len(weekdays) == 7: + await interaction.response.send_message( + "You are subscribed to pair programming in #%s every day%s" % ( + channel_name, + " (skipping " + ", ".join(grouped_skips[channel_name]) + ")" if channel_name in grouped_skips else "" + ), + ephemeral=True + ) + else: + await interaction.response.send_message( + "You are subscribed to pair programming in #%s on %s%s" % ( + channel_name, + ", ".join(weekdays), + " (skipping " + ", ".join(grouped_skips[channel_name]) + ")" if channel_name in grouped_skips else "" + ), + ephemeral=True + ) + return + + response = "You are subscribed to pair programming in the following channels:\n" + for channel_name, weekdays in grouped_schedules.items(): + response += "* #%s: %s%s\n" % ( + channel_name, + ", ".join(weekdays), + "skipping " + ", ".join(grouped_skips[channel_name]) if channel_name in grouped_skips else "" + ) + await interaction.response.send_message(response, ephemeral=True) + + + +@app.command( + name="pairwith", + description="Start a pairing session with another channel member." +) +@discord.app_commands.guild_only() +async def _pair_with( + interaction: discord.Interaction, + user: discord.Member, +): + assert interaction.guild is not None + assert isinstance(interaction.channel, discord.TextChannel) + + if user.id == interaction.user.id: + await interaction.response.send_message( + f"You cannot pair with yourself.", + ephemeral=True + ) + return + + with Session.begin() as session: + if fail_on_inactive_channel(session, interaction): + return + + pairing = get_pairing(session, interaction, interaction.user, user) + + usernames = sorted([user.name for user in (interaction.user, user)]) + + if pairing is None: + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + channel_id=interaction.channel_id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + else: + thread = interaction.guild.get_thread(pairing.thread_id) + if thread is None: + session.delete(pairing) + + # Make new discord thread + thread = await interaction.channel.create_thread( + name=f"{usernames[0]} & {usernames[1]}", + ) + + pairing = Pairing( + channel_id=interaction.channel_id, + user_1_id=interaction.user.id, + user_2_id=user.id, + ) + session.add(pairing) + + await thread.send( + f"<@{interaction.user.id}> has started a pairing session with you, <@{user.id}>. Happy pairing! :computer:" + ) + await interaction.response.send_message( + f"Successfully created pairing thread with <@{user.id}>", + ephemeral=True + ) + + +@discord.ext.tasks.loop(time=config.PAIRING_TIME) +async def make_groups(): + app.logger.info("Creating random groups.") + + with Session.begin() as session: + for guild in app.client.guilds: + channels = ( + session.query(PairingChannel) + .filter(PairingChannel.guild_id == guild.id) + .all() + ) + for pairing_channel in channels: + channel = app.client.get_channel(pairing_channel.channel_id) + if channel is None: + session.delete(pairing_channel) + app.logger.info("Pairing channel is gone :(") + continue + assert isinstance(channel, discord.TextChannel) + schedules = ( + session.query(Schedule) + .filter(Schedule.channel_id == channel.id) + .all() + ) + + todays_users = [ + guild.get_member(schedule.user_id) + for schedule in schedules + if schedule.is_available_on(Weekday(datetime.now().weekday())) + ] + + if len(todays_users) < 2: + app.logger.info("Not enough users to make groups.") + + random.shuffle(todays_users) + + groups = [] + if len(todays_users) % 2 == 1: + pair = ( + todays_users.pop(), + todays_users.pop(), + todays_users.pop(), + ) + groups.append(pair) + + while len(todays_users) > 0: + pair = ( + todays_users.pop(), + todays_users.pop(), + None, + ) + groups.append(pair) + + for group in groups: + usernames = sorted([user.name for user in group]) + user1, user2, user3 = group + + # TODO find existing pairing and get thread + pairing = Pairing( + channel_id=channel.id, + user_1_id=user1.id, + user_2_id=user2.id, + user_3_id=user3.id if user3 is not None else None, + ) + session.add(pairing) + + # Make new discord thread + thread = await channel.create_thread( + name=" & ".join(usernames) + ) + + await thread.send( + " ".join("<@%d>" % (user.id for user in group if user is not None)) + ) + await thread.send( + "You have been matched together. Happy pairing! :computer:" + ) + + app.logger.info("Send out pairings.") + + +def run(): + app.run(config.DISCORD_BOT_TOKEN) diff --git a/pairbot/models.py b/pairbot/models.py index b332273..e7794c0 100644 --- a/pairbot/models.py +++ b/pairbot/models.py @@ -6,10 +6,14 @@ import enum from datetime import datetime +from typing import ( + Optional, + Union, +) + from sqlalchemy import ( PrimaryKeyConstraint, Boolean, - Column, ) from sqlalchemy.orm import ( DeclarativeBase, @@ -32,9 +36,6 @@ class PairingChannel(Base): channel_id: Mapped[int] """The Discord channel ID.""" - active: Mapped[bool] - """Whether Pairbot is active in the channel.""" - leetcode_integration: Mapped[bool] """Whether leetcode integration is active in the channel or not.""" @@ -67,32 +68,43 @@ class Schedule(Base): user_id: Mapped[int] """The Discord user ID.""" - available_Monday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Tuesday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Wednesday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Thursday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Friday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Saturday: Mapped[bool] = mapped_column(Boolean, default=False) - available_Sunday: Mapped[bool] = mapped_column(Boolean, default=False) + available_0: Mapped[bool] = mapped_column(Boolean, default=False) + available_1: Mapped[bool] = mapped_column(Boolean, default=False) + available_2: Mapped[bool] = mapped_column(Boolean, default=False) + available_3: Mapped[bool] = mapped_column(Boolean, default=False) + available_4: Mapped[bool] = mapped_column(Boolean, default=False) + available_5: Mapped[bool] = mapped_column(Boolean, default=False) + available_6: Mapped[bool] = mapped_column(Boolean, default=False) __table_args__ = ( PrimaryKeyConstraint("channel_id", "user_id"), ) - # hack to get Schedule[day_of_week] to work - def __getitem__(self, day_of_week: Weekday): - return getattr(self, f"available_{day_of_week}") + # Utility methods + + def is_available_on(self, day_of_week: Weekday) -> bool: + return getattr(self, f"available_{int(day_of_week)}") + + def set_availability_on(self, day_of_week: Weekday, value: bool): + setattr(self, f"available_{int(day_of_week)}", value) + + def set_availability_every_day(self, value: bool): + for day in Weekday: + self.set_availability_on(day, value) + + def days_available(self) -> list[Weekday]: + return [day for day in Weekday if self.is_available_on(day)] - def __setitem__(self, day_of_week: Weekday, value: bool): - setattr(self, f"available_{day_of_week}", value) + def days_unavailable(self) -> list[Weekday]: + return [day for day in Weekday if not self.is_available_on(day)] - @property - def days_available(self): - return [day for day in Weekday if self[day] == True] + def num_days_available(self) -> int: + return sum(1 for day in Weekday if self.is_available_on(day)) -class ScheduleAdjustment(Base): + +class Skip(Base): """Represents an adjustment to a user's availability on a specific date.""" - __tablename__ = "schedule_adjustment" + __tablename__ = "skip" channel_id: Mapped[int] """The Discord channel ID.""" @@ -103,16 +115,13 @@ class ScheduleAdjustment(Base): date: Mapped[datetime] """The date on which the user's availability is set.""" - available: Mapped[bool] - """Whether the user is available on this date.""" - __table_args__ = ( PrimaryKeyConstraint("channel_id", "user_id", "date"), ) -class Thread(Base): - """Represents a user's membership in a Discord thread created by Pairbot.""" +class Pairing(Base): + """Represents a Discord thread created by Pairbot.""" __tablename__ = "pairing" channel_id: Mapped[int] @@ -121,9 +130,15 @@ class Thread(Base): thread_id: Mapped[int] """The Discord thread ID""" - user_id: Mapped[int] - """The Discord user ID.""" + user_1_id: Mapped[int] + """The first Discord user ID.""" + + user_2_id: Mapped[int] + """The second Discord user ID.""" + + user_3_id: Mapped[Optional[int]] + """The third Discord user ID.""" __table_args__ = ( - PrimaryKeyConstraint("channel_id", "thread_id", "user_id"), + PrimaryKeyConstraint("channel_id", "thread_id"), ) From e16956ff03813b5209239484cbdb5809b420a326 Mon Sep 17 00:00:00 2001 From: Yaxel Date: Thu, 30 Nov 2023 23:57:54 -0800 Subject: [PATCH 3/3] :> --- pairbot/app.py | 2 +- pairbot/client.old.py | 384 ---------------------- pairbot/client.py | 748 ------------------------------------------ pairbot/context.py | 57 ---- pairbot/db.py | 133 -------- pairbot/discord.py | 90 ----- pairbot/globals.py | 6 - pairbot/main.py | 164 ++++++--- pairbot/utils.py | 25 -- 9 files changed, 113 insertions(+), 1496 deletions(-) delete mode 100644 pairbot/client.old.py delete mode 100644 pairbot/client.py delete mode 100644 pairbot/context.py delete mode 100644 pairbot/db.py delete mode 100644 pairbot/discord.py delete mode 100644 pairbot/globals.py delete mode 100644 pairbot/utils.py diff --git a/pairbot/app.py b/pairbot/app.py index 4348ff0..fe32b47 100644 --- a/pairbot/app.py +++ b/pairbot/app.py @@ -85,9 +85,9 @@ async def wrapper(i: discord.Interaction[Any], *args: P.args, **kwargs: P.kwargs " }" if len(kwargs) > 0 else " with no arguments" ) - self.logger.info("Executing slash command /%s%s", i.command.name, pretty_kwargs) try: with InteractionContext(i): + self.logger.info("Executing slash command /%s%s", i.command.name, pretty_kwargs) await callback(i, *args, **kwargs) except Exception as e: self.logger.error(e, exc_info=True) diff --git a/pairbot/client.old.py b/pairbot/client.old.py deleted file mode 100644 index 7b77418..0000000 --- a/pairbot/client.old.py +++ /dev/null @@ -1,384 +0,0 @@ -import json -import logging -import os -import random -import sqlite3 -from datetime import datetime -from pathlib import Path -from typing import List - -import discord -from discord import app_commands -from discord.ext import tasks -from dotenv import load_dotenv - -from .db import PairingsDB, ScheduleDB, Timeblock -from .utils import get_user_name, parse_args, read_guild_to_channel - -load_dotenv() -args = parse_args() -if args.dev: - print("Running in dev mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN_DEV") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds-dev.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule-dev.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings-dev.db" - LOG_FILE = "pairbot-dev.log" -else: - print("Running in prod mode.") - BOT_TOKEN = os.getenv("BOT_TOKEN") - DATA_DIR = "data" - GUILDS_PATH = f"{DATA_DIR}/guilds.json" - SCHEDULE_DB_PATH = f"{DATA_DIR}/schedule.db" - PAIRINGS_DB_PATH = f"{DATA_DIR}/pairings.db" - LOG_FILE = "pairbot.log" -SORRY = "Unexpected error." -Path(DATA_DIR).mkdir(parents=True, exist_ok=True) - -logging.basicConfig( - filename=LOG_FILE, - filemode="a", - format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - level=logging.DEBUG, -) -logger = logging.getLogger("pairbot") - -intents = discord.Intents.default() -intents.members = True -intents.message_content = True -client = discord.Client(intents=intents) -tree = app_commands.CommandTree(client) -db = ScheduleDB(SCHEDULE_DB_PATH) -pairings_db = PairingsDB(PAIRINGS_DB_PATH) - - -# Discord API currently doesn't support variadic arguments -# https://github.com/discord/discord-api-docs/discussions/3286 -@tree.command( - name="subscribe", - description="Add timeblocks to find a partner for pair programming. \ -Matches go out at 8am UTC that day.", -) -@app_commands.describe( - timeblock="Choose WEEK to get a partner for the whole week (pairs announced Monday UTC)." -) -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] -) -async def _subscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.insert(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} subscribed T:{timeblock.name}." - ) - msg = ( - f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`. " - f"You can call `/subscribe` again to sign up for more days." - ) - await interaction.response.send_message(msg, ephemeral=True) - except sqlite3.IntegrityError as e: - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} failed subscribe T:{timeblock.name}." - ) - logger.warning(e, exc_info=True) - msg = ( - f"You are already subscribed to {timeblock}. " - f"Call `/unsubscribe` to remove a subscription or `/schedule` to view your schedule." - ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="unsubscribe", description="Remove timeblocks for pair programming.") -@app_commands.describe(timeblock="Call `/unsubscribe-all` to remove all timeblocks.") -@app_commands.choices( - timeblock=[ - app_commands.Choice(name=Timeblock.WEEK.name, value=Timeblock.WEEK.value), - app_commands.Choice(name=Timeblock.Monday.name, value=Timeblock.Monday.value), - app_commands.Choice(name=Timeblock.Tuesday.name, value=Timeblock.Tuesday.value), - app_commands.Choice( - name=Timeblock.Wednesday.name, value=Timeblock.Wednesday.value - ), - app_commands.Choice( - name=Timeblock.Thursday.name, value=Timeblock.Thursday.value - ), - app_commands.Choice(name=Timeblock.Friday.name, value=Timeblock.Friday.value), - app_commands.Choice( - name=Timeblock.Saturday.name, value=Timeblock.Saturday.value - ), - app_commands.Choice(name=Timeblock.Sunday.name, value=Timeblock.Sunday.value), - ] -) -async def _unsubscribe(interaction: discord.Interaction, timeblock: Timeblock): - try: - db.delete(interaction.guild_id, interaction.user.id, timeblock) - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} unsubscribed T:{timeblock.name}." - ) - msg = f"Your new schedule is `{Timeblock.generate_schedule(timeblocks)}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="unsubscribe-all", description="Remove all timeblocks for pair programming." -) -async def _unsubscribe_all(interaction: discord.Interaction): - try: - db.unsubscribe(interaction.guild_id, interaction.user.id) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} called unsubscribe-all." - ) - msg = "Your pairing subscriptions have been removed. To rejoin, call `/subscribe` again." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command(name="schedule", description="View your pairing schedule.") -async def _schedule(interaction: discord.Interaction): - try: - timeblocks = db.query_userid(interaction.guild_id, interaction.user.id) - schedule = Timeblock.generate_schedule(timeblocks) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} queried schedule {schedule}." - ) - msg = ( - f"Your current schedule is `{schedule}`. " - "You can call `/subscribe` or `/unsubscribe` to modify it." - ) - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="set-channel", description="Set a channel for bot messages (admin only)." -) -@app_commands.checks.has_permissions(administrator=True) -async def _set_channel(interaction: discord.Interaction, channel: discord.TextChannel): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - guild_to_channel[str(interaction.guild_id)] = channel.id - with open(GUILDS_PATH, "w") as f: - json.dump(guild_to_channel, f) - logger.info( - f"G:{interaction.guild_id} U:{interaction.user.id} set-channel C:{channel.id}." - ) - msg = f"Successfully set bot channel to `{channel.name}`." - await interaction.response.send_message(msg, ephemeral=True) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -@tree.command( - name="pairwith", - description="Start an immediate pairing session with another member.", -) -async def _pairwith(interaction: discord.Interaction, user: discord.Member): - try: - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel_id = guild_to_channel[str(interaction.guild_id)] - channel = client.get_channel(channel_id) - users = [interaction.user, user] - notify_msg = ( - f"<@{interaction.user.id}> has started an on-demand pair with you, <@{user.id}>. " - "Happy pairing! :computer:" - ) - await create_group_thread(interaction.guild_id, users, channel, notify_msg) - logger.info( - f"G:{interaction.guild_id} C:{channel.id} on-demand paired U:{interaction.user.id} with {user.id}." - ) - await interaction.response.send_message( - f"Thread with {get_user_name(user)} created in channel `{channel.name}`.", - ephemeral=True, - ) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message(SORRY, ephemeral=True) - - -async def dm_user(user: discord.User, msg: str): - try: - channel = await user.create_dm() - await channel.send(msg) - except Exception as e: - logger.error(e, exc_info=True) - - -async def create_group_thread( - guild_id: int, - users: List[discord.User], - channel: discord.TextChannel, - notify_msg: str, -): - # @ notifying users in a private thread invites them - # so `notify_msg` must notify for this to work - userids = [user.id for user in users] - thread_id = pairings_db.query_userids(guild_id, userids, channel.id) - thread = None - if thread_id is not None: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - try: - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - except discord.errors.NotFound: - logger.debug(f"Couldn't fetch thread {thread_id}, maybe deleted?") - pairings_db.delete(guild_id, userids, channel.id, thread_id) - if thread is None: - title = ", ".join(get_user_name(user) for user in users) - thread = await channel.create_thread( - name=f"{title}", auto_archive_duration=10080 - ) - logger.debug(f"Created new thread {thread.id} for G:{guild_id} U:{userids}") - pairings_db.insert(guild_id, userids, channel.id, thread.id) - else: - logger.debug(f"Found existing thread {thread_id} for G:{guild_id} U:{userids}") - guild = client.get_guild(guild_id) - thread = await guild.fetch_channel(thread_id) - await thread.send(notify_msg) - - -async def on_tree_error( - interaction: discord.Interaction, error: app_commands.AppCommandError -): - if isinstance(error, app_commands.CommandOnCooldown): - return await interaction.response.send_message( - f"Command is currently on cooldown, try again in {error.retry_after:.2f} seconds.", - ephemeral=True, - ) - elif isinstance(error, app_commands.MissingPermissions): - return await interaction.response.send_message( - "You don't have the permissions to do that.", ephemeral=True - ) - else: - raise error - - -@tasks.loop(hours=1) -async def pairing_cron(): - def should_run(): - now = datetime.utcnow() - hour = now.time().hour - logger.debug(f"Checking pairing job at UTC:{now}.") - return hour == 8 - - if should_run(): - await run_pairing() - - -async def run_pairing(): - now = datetime.utcnow() - print(now) - logger.debug(f"Running pairing job at UTC:{now}.") - weekday = now.weekday() - weekday_map = { - 0: Timeblock.Monday, - 1: Timeblock.Tuesday, - 2: Timeblock.Wednesday, - 3: Timeblock.Thursday, - 4: Timeblock.Friday, - 5: Timeblock.Saturday, - 6: Timeblock.Sunday, - } - timeblock = weekday_map[weekday] - for guild in client.guilds: - await pair(guild.id, timeblock) - # weekly Monday match - if weekday == 0: - await pair(guild.id, Timeblock.WEEK) - - -async def pair(guild_id: int, timeblock: Timeblock): - try: - userids = db.query_timeblock(guild_id, timeblock) - users = [client.get_user(userid) for userid in userids] - # Users may leave the server without unsubscribing - # TODO: listen to that event and drop them from the table - users = list(filter(None, users)) - logger.info( - f"Pairing for G:{guild_id} T:{timeblock.name} with {len(users)}/{len(userids)} users." - ) - guild_to_channel = read_guild_to_channel(GUILDS_PATH) - channel = client.get_channel(guild_to_channel[str(guild_id)]) - if len(users) < 2: - for user in users: - logger.info( - f"G:{guild_id} T:{timeblock.name} pair failed, dming U:{user.id}." - ) - msg = ( - f"Thanks for signing up for pairing this {timeblock}. " - "Unfortunately, there was nobody else available this time." - ) - await dm_user(user, msg) - await channel.send( - f"Not enough signups this {timeblock}. Try `/subscribe` to sign up!" - ) - return - - random.shuffle(users) - groups = [users[i :: len(users) // 2] for i in range(len(users) // 2)] - for group in groups: - notify_msg = ", ".join(f"<@{user.id}>" for user in group) - notify_msg = f"{notify_msg}: you've been matched together for this {timeblock}. Happy pairing! :computer:" - await create_group_thread(guild_id, group, channel, notify_msg) - logger.info( - f"G:{guild_id} C:{channel.id} paired U:{[user.id for user in group]}." - ) - await channel.send( - f"Pairings for {len(users)} users have been sent out for this {timeblock}. Try `/subscribe` to sign up!" - ) - except Exception as e: - logger.error(e, exc_info=True) - - -def local_setup(): - try: - read_guild_to_channel(GUILDS_PATH) - except Exception: - with open(GUILDS_PATH, "w") as f: - json.dump({}, f) - - -@client.event -async def on_ready(): - local_setup() - await client.wait_until_ready() - tree.on_error = on_tree_error - for guild in client.guilds: - tree.copy_global_to(guild=guild) - await tree.sync(guild=guild) - print("Code sync complete!") - pairing_cron.start() - print("Starting cron loop...") - logger.info("Bot started.") - - -def run(): - client.run(BOT_TOKEN) diff --git a/pairbot/client.py b/pairbot/client.py deleted file mode 100644 index 5f93540..0000000 --- a/pairbot/client.py +++ /dev/null @@ -1,748 +0,0 @@ -"""client.py - -This module describes the Discord slash command interface and corresponding logic. -""" - -import functools -import logging - -from typing import ( - Any, - Optional, - TypeVar, - ParamSpec, - Callable, - Concatenate, - Coroutine, -) - -import sqlalchemy -from sqlalchemy import ( - create_engine, - func, - or_, -) -import sqlalchemy.orm -from sqlalchemy.orm import ( - Session, - sessionmaker, - -) - -import discord -import discord.ext.commands -import discord.ext.tasks - -from datetime import datetime, timedelta -import dateparser - -from . import ( - config, -) - -from .models import ( - Weekday, - PairingChannel, - Schedule, - ScheduleAdjustment, - Pairing, -) - -logger = logging.getLogger(__name__) - -# type fuckery for the command decorator -T = TypeVar("T") -P = ParamSpec("P") -CommandCallback = Callable[Concatenate[discord.Interaction[Any], P], Coroutine[Any, Any, T]] - -class Pairbot(discord.Client): - """Represents a running instance of Pairbot.""" - - def __init__(self, intents: discord.Intents, **options: Any) -> None: - super().__init__(intents=intents, **options) - self.tree = discord.app_commands.CommandTree(self) - - self.db_engine = create_engine(config.DATABASE_URL) - self.make_orm_session = sessionmaker(self.db_engine) - - def command( - self, - **options: Any - ): - """Wrapper for pairbot slash commands with logging and error-handling.""" - def decorator(callback: CommandCallback): - @self.tree.command(**options) - @discord.app_commands.guild_only() - @functools.wraps(callback) - async def wrapper( - interaction: discord.Interaction[Any], - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - # Keep type checker happy - assert interaction.command is not None - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - assert isinstance(interaction.user, discord.Member) - - # Log command execution - pretty_kwargs = ( - " with arguments { " + - ", ".join((f"{key}=\"{str(value)}\"" for key, value in kwargs.items())) + - " }" - if len(kwargs) > 0 else " with no arguments" - ) - logger.info( - f"User \"{interaction.user.name}\" executed command /{interaction.command.name}{pretty_kwargs} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - try: - await callback(interaction, *args, **kwargs) - except Exception as e: - logger.error(e, exc_info=True) - await interaction.response.send_message("Pairbot broke somehow! :v", ephemeral=True) - - return wrapper - return decorator - - async def on_ready(self) -> None: - for guild in self.guilds: - logger.info(f"Copying command tree to guild \"{guild.name}\"") - self.tree.copy_global_to(guild=guild) - await self.tree.sync(guild=guild) - logger.info("Pairbot ready.") - - -# Instantiate client and register slash commands -intents = discord.Intents.all() -client = Pairbot( - intents=intents, -) - - -@client.command( - name="addpairbot", - description="Add Pairbot to the current channel." -) -@discord.app_commands.checks.has_permissions(administrator=True) -async def _add_pairbot(interaction: discord.Interaction): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - - if channel is not None: - if channel.active: - logger.info( - f"Pairbot is already added to guild \"{interaction.guild.name}\", channel \"#{interaction.channel}\"." - ) - await interaction.response.send_message( - f"Pairbot is already added to \"#{interaction.channel.name}\"." - ) - return - else: - channel.active = True - else: - channel = PairingChannel( - guild_id = interaction.guild.id, - channel_id = interaction.channel.id, - active = True, - leetcode_integration = False, # TODO - ) - session.add(channel) - - session.commit() - - logger.info( - f"Added Pairbot to guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - await interaction.response.send_message( - f"Added Pairbot to \"#{interaction.channel.name}\"." - ) - - -@client.command( - name="removepairbot", - description="Remove Pairbot from the current channel." -) -@discord.app_commands.guild_only() -@discord.app_commands.checks.has_permissions(administrator=True) -async def _remove_pairbot(interaction: discord.Interaction): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - - if channel is None or not channel.active: - logger.info( - f"Pairbot is not added to guild \"{interaction.guild.name}\", channel \"#{interaction.channel}\"." - ) - await interaction.response.send_message( - f"Pairbot is not added to \"#{interaction.channel.name}\"." - ) - else: - channel.active = False - session.commit() - logger.info( - f"Removed Pairbot from guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - await interaction.response.send_message( - f"Removed Pairbot from \"#{interaction.channel.name}\"." - ) - - -@client.command( - name="subscribe", - description="Subscribe to pair programming (every day if no weekday specified)." -) -@discord.app_commands.guild_only() -async def _subscribe( - interaction: discord.Interaction, - weekday: Optional[Weekday], -): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - if channel is None: - await interaction.response.send_message( - f"Pairbot is not active in this channel.", - ephemeral=True - ) - return - - schedule = ( - session.query(Schedule) - .filter(Schedule.channel_id == interaction.channel_id) - .filter(Schedule.user_id == interaction.user.id) - .one_or_none() - ) - - if schedule is None: - schedule = Schedule( - channel_id = interaction.channel_id, - user_id = interaction.user.id, - ) - session.add(schedule) - - if weekday is not None: - if schedule[weekday] == True: - await interaction.response.send_message( - f"You are already subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}.", - ephemeral=True - ) - return - else: - schedule[weekday] = True - - session.commit() - - logger.info( - f"Subscribed user \"{interaction.user.name}\" to pair programming on {str(weekday)} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}." - else: - if len(schedule.days_available) == 7: - await interaction.response.send_message( - f"You are already subscribed to pair programming every day.", - ephemeral=True - ) - return - for day in Weekday: - schedule[day] = True - - session.commit() - - logger.info( - f"Subscribed user \"{interaction.user.name}\" to daily pair programming in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully subscribed to daily pair programming in #{interaction.channel.name}." - await interaction.response.send_message(msg, ephemeral=True) - - -@client.command( - name="unsubscribe", - description="Unsubscribe from pair programming (every day if no weekday specified)." -) -@discord.app_commands.guild_only() -async def _unsubscribe( - interaction: discord.Interaction, - weekday: Optional[Weekday], -): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - if channel is None: - await interaction.response.send_message( - f"Pairbot is not active in this channel.", - ephemeral=True - ) - return - - schedule = ( - session.query(Schedule) - .filter(Schedule.channel_id == interaction.channel_id) - .filter(Schedule.user_id == interaction.user.id) - .one_or_none() - ) - - if schedule is None: - await interaction.response.send_message( - f"You are already not subscribed to pair programming in #{interaction.channel.name}.", - ephemeral=True - ) - return - - if weekday is not None: - if not schedule[weekday]: - await interaction.response.send_message( - f"You are already not subscribed to pair programming on {str(weekday)} in #{interaction.channel.name}.", - ephemeral=True - ) - return - - schedule[weekday] = False - session.commit() - - logger.info( - f"Unsubscribed user \"{interaction.user.name}\" from pair programming on {str(weekday)} in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully unsubscribed from pair programming on {str(weekday)} in #{interaction.channel.name}." - else: - if len(schedule.days_available) == 0: - await interaction.response.send_message( - f"You are already not subscribed to pair programming.", - ephemeral=True - ) - return - for day in Weekday: - schedule[day] = False - session.commit() - - logger.info( - f"Unsubscribed user \"{interaction.user.name}\" from all pair programming in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully unsubscribed from all pair programming in #{interaction.channel.name}." - - await interaction.response.send_message(msg, ephemeral=True) - - -@client.command( - name="skip", - description="Mark yourself as unavailable for pair programming on some date in the future)." -) -@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") -@discord.app_commands.guild_only() -async def _skip( - interaction: discord.Interaction, - human_date: Optional[str], -): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - if channel is None: - await interaction.response.send_message( - f"Pairbot is not active in this channel.", - ephemeral=True - ) - return - - schedule = ( - session.query(Schedule) - .filter(Schedule.channel_id == interaction.channel_id) - .filter(Schedule.user_id == interaction.user.id) - .one_or_none() - ) - - if schedule is None or len(schedule.days_available) == 0: - await interaction.response.send_message( - f"You are not subscribed to pair programming in #{interaction.channel.name}.", - ephemeral=True - ) - return - - if human_date is None: - adjustment_date = None - current_weekday = datetime.now().weekday() - for i in range(0, 7): - weekday = Weekday((current_weekday + i) % 7) - if schedule[weekday]: - adjustment_date = datetime.now().date() + timedelta(days=i) - break - - assert adjustment_date is not None - else: - # Clean things up for the parser - cleaned_date = human_date.lower() - cleaned_date = cleaned_date.replace("next", "") - - # Try to parse the date - adjustment_datetime = dateparser.parse( - cleaned_date, - settings={ - "PREFER_DATES_FROM": "future", - "RELATIVE_BASE": datetime.now(), - }, - languages=["en"], - ) - if adjustment_datetime is None: - await interaction.response.send_message( - f"Could not parse date \"{human_date}\".", - ephemeral=True - ) - return - - adjustment_date = adjustment_datetime.date() - if adjustment_datetime < datetime.now(): - await interaction.response.send_message( - f"Cannot skip a date in the past: {adjustment_date.strftime('%A %B %d, %Y')}.", - ephemeral=True - ) - return - - weekday = Weekday(adjustment_date.weekday()) - if schedule[weekday] == False: - await interaction.response.send_message( - f"You are not subscribed to pair programming in #{interaction.channel.name} on {weekday}.", - ephemeral=True - ) - return - - adjustment = ( - session.query(ScheduleAdjustment) - .filter(ScheduleAdjustment.channel_id == interaction.channel_id) - .filter(ScheduleAdjustment.user_id == interaction.user.id) - .filter(func.DATE(ScheduleAdjustment.date) == adjustment_date) - .one_or_none() - ) - - if adjustment is None: - adjustment = ScheduleAdjustment( - channel_id=interaction.channel_id, - user_id=interaction.user.id, - date=adjustment_date, - available=False, - ) - session.add(adjustment) - else: - if adjustment.available == False: - await interaction.response.send_message( - f"You already skipped pairing on {adjustment_date.strftime('%A %B %d, %Y')}.", - ephemeral=True - ) - return - adjustment.available = False - session.commit() - - logger.info( - f"Skipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')} for user \"{interaction.user.name}\" in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully skipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')}." - await interaction.response.send_message(msg, ephemeral=True) - - -@client.command( - name="unskip", - description="Mark yourself as available for pair programming on some date in the future." -) -@discord.app_commands.describe(human_date="A human-readable date like \"tomorrow\" or \"January 1\".") -@discord.app_commands.guild_only() -async def _unskip( - interaction: discord.Interaction, - human_date: Optional[str], -): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - session = client.make_orm_session() - with session.begin(): - channel = ( - session.query(PairingChannel) - .filter(PairingChannel.channel_id == interaction.channel_id) - .one_or_none() - ) - if channel is None: - await interaction.response.send_message( - f"Pairbot is not active in this channel.", - ephemeral=True - ) - return - - schedule = ( - session.query(Schedule) - .filter(Schedule.channel_id == interaction.channel_id) - .filter(Schedule.user_id == interaction.user.id) - .one_or_none() - ) - - if schedule is None or len(schedule.days_available) == 0: - await interaction.response.send_message( - f"You are not subscribed to pair programming in #{interaction.channel.name}.", - ephemeral=True - ) - return - - if human_date is None: - adjustment_date = None - current_weekday = datetime.now().weekday() - for i in range(0, 7): - weekday = Weekday((current_weekday + i) % 7) - if schedule[weekday]: - adjustment_date = datetime.now().date() + timedelta(days=i) - break - - assert adjustment_date is not None - else: - # Clean things up for the parser - cleaned_date = human_date.lower() - cleaned_date = cleaned_date.replace("next", "") - - # Try to parse the date - adjustment_datetime = dateparser.parse( - cleaned_date, - settings={ - "PREFER_DATES_FROM": "future", - "RELATIVE_BASE": datetime.now(), - }, - languages=["en"], - ) - if adjustment_datetime is None: - await interaction.response.send_message( - f"Could not parse date \"{human_date}\".", - ephemeral=True - ) - return - - adjustment_date = adjustment_datetime.date() - if adjustment_datetime < datetime.now(): - await interaction.response.send_message( - f"Cannot unskip a date in the past: {adjustment_date.strftime('%A %B %d, %Y')}.", - ephemeral=True - ) - return - - weekday = Weekday(adjustment_date.weekday()) - if schedule[weekday] == False: - await interaction.response.send_message( - f"You are not subscribed to pair programming in #{interaction.channel.name} on {weekday}.", - ephemeral=True - ) - return - - adjustment = ( - session.query(ScheduleAdjustment) - .filter(ScheduleAdjustment.channel_id == interaction.channel_id) - .filter(ScheduleAdjustment.user_id == interaction.user.id) - .filter(func.DATE(ScheduleAdjustment.date) == adjustment_date) - .one_or_none() - ) - - if adjustment is None: - adjustment = ScheduleAdjustment( - channel_id=interaction.channel_id, - user_id=interaction.user.id, - date=adjustment_date, - available=True, - ) - session.add(adjustment) - else: - if adjustment.available == True: - await interaction.response.send_message( - f"You already unskipped pairing on {adjustment_date.strftime('%A %B %d, %Y')}.", - ephemeral=True - ) - return - adjustment.available = True - session.commit() - - logger.info( - f"Unskipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')} for user \"{interaction.user.name}\" in guild \"{interaction.guild.name}\", channel \"#{interaction.channel.name}\"." - ) - - msg = f"Successfully unskipped pair programming on {adjustment_date.strftime('%A %B %d, %Y')}." - await interaction.response.send_message(msg, ephemeral=True) - - -@client.command( - name="viewschedule", - description="View your pair programming schedule." -) -@discord.app_commands.guild_only() -async def _view_schedule( - interaction: discord.Interaction, -): - assert interaction.guild is not None - - session = client.make_orm_session() - with session.begin(): - schedules = ( - session.query(Schedule) - .filter(Schedule.user_id == interaction.user.id) - .all() - ) - - adjustments = ( - session.query(ScheduleAdjustment) - .filter(ScheduleAdjustment.user_id == interaction.user.id) - .all() - ) - - guild_schedule = { - interaction.guild.get_channel(schedule.channel_id): [str(day) for day in schedule.days_available] - for schedule in schedules - } - - skipped = dict() - unskipped = dict() - - for adjustment in adjustments: - channel_name = interaction.guild.get_channel(adjustment.channel_id) - if adjustment.available == False: - if channel_name not in skipped: - skipped[channel_name] = [] - skipped[channel_name].append(adjustment.date.strftime("%a %b %d")) - if adjustment.available == True: - if channel_name not in unskipped: - unskipped[channel_name] = [] - unskipped[channel_name].append(adjustment.date.strftime("%a %b %d")) - - if len(guild_schedule) == 0 or sum([len(days) for days in guild_schedule.values()]) == 0: - msg = "You are not subscribed to pair programming." - elif len(guild_schedule) == 1: - channel_name, schedule = guild_schedule.popitem() - msg = f"You are subscribed to pair programming in #{channel_name} on {', '.join(schedule)}" - if channel_name in skipped and len(skipped[channel_name]) > 0: - msg += f" (skipping {', '.join(skipped[channel_name])})" - else: - msg = "You are subscribed to pair programming in the following channels:\n" - for channel_name, schedule in guild_schedule.items(): - msg += f"* #{channel_name}: {', '.join(schedule)}" - if channel_name in skipped and len(skipped[channel_name]) > 0: - msg += f" (skipping {', '.join(skipped[channel_name])})" - msg += "\n" - - await interaction.response.send_message(msg, ephemeral=True) - - -@client.command( - name="pairwith", - description="Start a pairing session with another channel member." -) -@discord.app_commands.guild_only() -async def _pair_with( - interaction: discord.Interaction, - user: discord.Member, -): - assert interaction.guild is not None - assert isinstance(interaction.channel, discord.TextChannel) - - if user.id == interaction.user.id: - await interaction.response.send_message( - f"You cannot pair with yourself.", - ephemeral=True - ) - return - - session = client.make_orm_session() - with session.begin(): - pairing = ( - session.query(Pairing) - .filter(or_( - Pairing.user_1_id == interaction.user.id, - Pairing.user_2_id == interaction.user.id, - )) - .filter(or_( - Pairing.user_1_id == user.id, - Pairing.user_2_id == user.id, - )) - .one_or_none() - ) - - usernames = sorted([user.name for user in (interaction.user, user)]) - - if pairing is None: - # Make new discord thread - thread = await interaction.channel.create_thread( - name=f"{usernames[0]} & {usernames[1]}", - ) - - pairing = Pairing( - channel_id=interaction.channel.id, - thread_id=thread.id, - user_1_id=interaction.user.id, - user_2_id=user.id, - ) - session.add(pairing) - session.commit() - else: - thread = interaction.guild.get_thread(pairing.thread_id) - if thread is None: - session.delete(pairing) - session.commit() - - # Make new discord thread - thread = await interaction.channel.create_thread( - name=f"{usernames[0]} & {usernames[1]}", - ) - - pairing = Pairing( - channel_id=interaction.channel.id, - thread_id=thread.id, - user_1_id=interaction.user.id, - user_2_id=user.id, - ) - session.add(pairing) - session.commit() - - await thread.send( - f"<@{interaction.user.id}> has started a pairing session with you, <@{user.id}>. Happy pairing! :computer:" - ) - await interaction.response.send_message( - f"Successfully created pairing thread with <@{user.id}>", - ephemeral=True - ) - - -@discord.ext.tasks.loop(time=config.PAIRING_TIME) -async def make_groups(): - pass - - -def run(): - client.run(config.DISCORD_BOT_TOKEN) diff --git a/pairbot/context.py b/pairbot/context.py deleted file mode 100644 index f5797c0..0000000 --- a/pairbot/context.py +++ /dev/null @@ -1,57 +0,0 @@ - -import contextvars -from types import TracebackType -from typing import ( - Any, - Self, -) - -from discord import Interaction - -from . import globals -from .app import App - -class AppContext: - """The app context contains application-specific information (similar to Flask's application - context).""" - - def __init__( - self, - app: App, - ): - self.app = app - self._cv_tokens: list[contextvars.Token] = [] - - def push(self) -> None: - """Binds the app context to the current context.""" - self._cv_tokens.append(globals.app.set(self)) - - def pop(self, exc: BaseException | None): - ctx = globals.app.get() - assert ctx is self - globals.app.reset(self._cv_tokens.pop()) - - def __enter__(self) -> Self: - self.push() - return self - - def __exit__( - self, - exc_type: type | None, - exc_value: BaseException | None, - tb: TracebackType | None - ) -> None: - self.pop(exc_value) - - -class InteractionContext: - """The interaction context contains interaction-specific information (similar to Flask's request - context). It is created and pushed at the beginning of an interaction, and popped at the end.""" - - def __init__( - self, - app: App, - interaction: Interaction[Any], - ): - self.app = app - self.interaction = interaction diff --git a/pairbot/db.py b/pairbot/db.py deleted file mode 100644 index a87df35..0000000 --- a/pairbot/db.py +++ /dev/null @@ -1,133 +0,0 @@ -import sqlite3 -from contextlib import closing -from enum import Enum, auto -from typing import List, Optional - - -class Timeblock(Enum): - WEEK = auto() - Monday = auto() - Tuesday = auto() - Wednesday = auto() - Thursday = auto() - Friday = auto() - Saturday = auto() - Sunday = auto() - - def __str__(self): - return self.name - - @staticmethod - def generate_schedule(timeblocks: List["Timeblock"]) -> str: - return f"{[str(block) for block in sorted(timeblocks, key=lambda block: block.value)]}" - - -class ScheduleDB: - def __init__(self, path: str) -> None: - self.db = path - self.con = sqlite3.connect(self.db) - self._setup() - - def _setup(self) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "CREATE TABLE IF NOT EXISTS users (guildid INTEGER, userid INTEGER, timeblock INTEGER, " - "UNIQUE (guildid, userid, timeblock))" - ) - self.con.commit() - - def insert(self, guild_id: int, user_id: int, timeblock: Timeblock) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "INSERT INTO users VALUES (?, ?, ?)", - (guild_id, user_id, timeblock.value), - ) - self.con.commit() - - def delete(self, guild_id: int, user_id: int, timeblock: Timeblock) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE from users WHERE guildid=? and userid=? and timeblock=?", - (guild_id, user_id, timeblock.value), - ) - self.con.commit() - - def unsubscribe(self, guild_id: int, user_id: int) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE FROM users WHERE guildid=? and userid=?", (guild_id, user_id) - ) - self.con.commit() - - def query_timeblock(self, guild_id: int, timeblock: Timeblock) -> List[int]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT userid FROM users WHERE guildid=? and timeblock=?", - (guild_id, timeblock.value), - ) - userids = list(map(lambda x: x[0], res.fetchall())) - return userids - - def query_userid(self, guild_id: int, user_id: int) -> List[Timeblock]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT timeblock FROM users WHERE guildid=? and userid=?", - (guild_id, user_id), - ) - timeblocks = list(map(lambda x: Timeblock(x[0]), res.fetchall())) - return timeblocks - - -class PairingsDB: - def __init__(self, path: str) -> None: - self.db = path - self.con = sqlite3.connect(self.db) - self._setup() - - @staticmethod - def _serialize_userids(userids: List[int]) -> str: - return ",".join(map(str, sorted(userids))) - - @staticmethod - def _deserialize_userids(userids_ser: str) -> List[int]: - return list(map(int, userids_ser.split(","))) - - def _setup(self) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "CREATE TABLE IF NOT EXISTS pairings (guildid INTEGER, userids TEXT, channelid INTEGER, " - "threadid INTEGER, UNIQUE (guildid, userids, channelid, threadid))" - ) - self.con.commit() - - def insert( - self, guild_id: int, userids: List[int], channelid: int, threadid: int - ) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "INSERT INTO pairings VALUES (?, ?, ?, ?)", - (guild_id, PairingsDB._serialize_userids(userids), channelid, threadid), - ) - self.con.commit() - - def delete( - self, guild_id: int, userids: List[int], channelid: int, threadid: int - ) -> None: - with closing(self.con.cursor()) as cur: - cur.execute( - "DELETE from pairings WHERE guildid=? and userids=? and channelid=? and threadid=?", - (guild_id, PairingsDB._serialize_userids(userids), channelid, threadid), - ) - self.con.commit() - - def query_userids( - self, guild_id: int, userids: List[int], channelid: int - ) -> Optional[int]: - with closing(self.con.cursor()) as cur: - res = cur.execute( - "SELECT threadid FROM pairings WHERE guildid=? and userids=? and channelid=?", - (guild_id, PairingsDB._serialize_userids(userids), channelid), - ) - results = res.fetchall() - if len(results) == 1: - return results[0][0] diff --git a/pairbot/discord.py b/pairbot/discord.py deleted file mode 100644 index 0338556..0000000 --- a/pairbot/discord.py +++ /dev/null @@ -1,90 +0,0 @@ -"""discord.py - -Utilities for interfacing with the Discord API.""" - -from abc import ABC -from typing import ( - Any, - Callable, - Concatenate, - Coroutine, - ParamSpec, - TypeVar, - TypedDict, - Union, -) - -import functools -import logging - -import discord - -logger = logging.getLogger(__name__) - - -class AppContext(TypedDict): - logger_adapter: Union[logging.Logger, logging.LoggerAdapter] - - -class AppCommandLoggerAdapter(logging.LoggerAdapter): - """Custom LoggerAdapter for adding context to logs.""" - def process(self, msg: str, kwargs: dict): - """Add guild, channel, and user IDs to slash command logs.""" - if self.extra is None: - return msg, kwargs - - guild_id = self.extra.get("guild_id") - channel_id = self.extra.get("channel_id") - user_id = self.extra.get("user_id") - - s = "" - if guild_id is not None: s += f"g: <{guild_id} " - if channel_id is not None: s += f"c: <{channel_id} " - if user_id is not None: s += f"u: <{user_id} " - s += msg - - return s, kwargs - -# Type fuckery for the command decorator -T = TypeVar("T") -P = ParamSpec("P") -CommandCallback = Callable[Concatenate[AppContext, discord.Interaction[Any], P], Coroutine[Any, Any, T]] - - -class Application(ABC): - """A wrapper around a discord client instance.""" - - def __init__( - self, - intents: discord.Intents = discord.Intents.all(), - **options: Any - ): - self.client = discord.Client(intents=intents, **options) - self.command_tree = discord.app_commands.CommandTree(self.client) - - def get_context(self, interaction: discord.Interaction[Any]) -> AppContext: - return { - "logger_adapter": AppCommandLoggerAdapter( - logger, - { - "guild_id": interaction.guild_id, - "channel_id": interaction.channel_id, - "user_id": interaction.user.id, - } - ), - } - - def command(self, **options: Any): - """Wrapper for adding logging and error-handling to application slash commands.""" - def decorator(callback: CommandCallback): - @self.command_tree.command(**options) - @functools.wraps(callback) - async def wrapper(interaction: discord.Interaction[Any], *args: P.args, **kwargs: P.kwargs) -> None: - await callback( - self.get_context(interaction), - interaction, - *args, - **kwargs - ) - return wrapper - return decorator diff --git a/pairbot/globals.py b/pairbot/globals.py deleted file mode 100644 index ff9bb36..0000000 --- a/pairbot/globals.py +++ /dev/null @@ -1,6 +0,0 @@ -from contextvars import ContextVar - -from .context import AppContext, InteractionContext - -app: ContextVar[AppContext] = ContextVar("app") -interaction: ContextVar[InteractionContext] = ContextVar("interaction") diff --git a/pairbot/main.py b/pairbot/main.py index 1452a4e..0175605 100644 --- a/pairbot/main.py +++ b/pairbot/main.py @@ -206,22 +206,43 @@ async def fail_on_nonexistent_subscription( def get_pairing( session: sqlalchemy.orm.Session, - interaction: discord.Interaction, user_1: discord.User | discord.Member, - user_2: discord.Member + user_2: discord.User | discord.Member, + user_3: Optional[discord.User | discord.Member] ) -> Optional[Pairing]: - return ( - session.query(Pairing) - .filter(sqlalchemy.or_( - Pairing.user_1_id == user_1.id, - Pairing.user_2_id == user_1.id, - )) - .filter(sqlalchemy.or_( - Pairing.user_1_id == user_2.id, - Pairing.user_2_id == user_2.id, - )) - .one_or_none() - ) + if user_3 is None: + return ( + session.query(Pairing) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_1.id, + Pairing.user_2_id == user_1.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_2.id, + Pairing.user_2_id == user_2.id, + )) + .one_or_none() + ) + else: + return ( + session.query(Pairing) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_1.id, + Pairing.user_2_id == user_1.id, + Pairing.user_3_id == user_1.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_2.id, + Pairing.user_2_id == user_2.id, + Pairing.user_3_id == user_2.id, + )) + .filter(sqlalchemy.or_( + Pairing.user_1_id == user_3.id, + Pairing.user_2_id == user_3.id, + Pairing.user_3_id == user_3.id, + )) + .one_or_none() + ) # Slash commands @@ -263,7 +284,6 @@ async def _add_pairbot(interaction: discord.Interaction): await interaction.response.send_message( "Added Pairbot to \"#%s\"." % interaction.channel.name, - ephemeral=True, ) @@ -298,7 +318,6 @@ async def _remove_pairbot(interaction: discord.Interaction): await interaction.response.send_message( "Removed Pairbot from \"#%s\"." % interaction.channel.name, - ephemeral=True, ) @@ -315,7 +334,7 @@ async def _subscribe( assert isinstance(interaction.channel, discord.TextChannel) with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return schedule = get_user_schedule(session, interaction) @@ -325,7 +344,7 @@ async def _subscribe( else: app.logger.info("Found existing schedule") - if fail_on_existing_subscription(interaction, schedule, weekday): + if await fail_on_existing_subscription(interaction, schedule, weekday): return if weekday is not None: @@ -357,7 +376,7 @@ async def _unsubscribe( assert isinstance(interaction.channel, discord.TextChannel) with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return schedule = get_user_schedule(session, interaction) @@ -369,13 +388,13 @@ async def _unsubscribe( ) return - if fail_on_nonexistent_subscription(interaction, schedule, weekday): + if await fail_on_nonexistent_subscription(interaction, schedule, weekday): return if weekday is not None: schedule.set_availability_on(weekday, False) await interaction.response.send_message( - ("Successfully unsubscribed from pair programming on %s in #%s." % weekday, interaction.channel.name), + ("Successfully unsubscribed from pair programming on %s in #%s." % (weekday, interaction.channel.name)), ephemeral=True ) else: @@ -400,12 +419,12 @@ async def _skip( assert isinstance(interaction.channel, discord.TextChannel) with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return schedule = get_user_schedule(session, interaction) - if fail_on_nonexistent_subscription(interaction, schedule, None): + if await fail_on_nonexistent_subscription(interaction, schedule, None): return assert schedule is not None @@ -431,7 +450,7 @@ async def _skip( skipped_weekday = Weekday(skipped_date.weekday()) - if fail_on_nonexistent_subscription(interaction, schedule, skipped_weekday): + if await fail_on_nonexistent_subscription(interaction, schedule, skipped_weekday): return skip = get_skip(session, interaction, skipped_date) @@ -466,12 +485,12 @@ async def _unskip( assert isinstance(interaction.channel, discord.TextChannel) with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return schedule = get_user_schedule(session, interaction) - if fail_on_nonexistent_subscription(interaction, schedule, None): + if await fail_on_nonexistent_subscription(interaction, schedule, None): return assert schedule is not None @@ -496,7 +515,7 @@ async def _unskip( unskipped_weekday = Weekday(unskipped_date.weekday()) - if fail_on_nonexistent_subscription(interaction, schedule, unskipped_weekday): + if await fail_on_nonexistent_subscription(interaction, schedule, unskipped_weekday): return skip = get_skip(session, interaction, unskipped_date) @@ -529,7 +548,7 @@ async def _view_schedule( assert isinstance(interaction.channel, discord.TextChannel) with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return all_schedules = ( @@ -621,10 +640,10 @@ async def _pair_with( return with Session.begin() as session: - if fail_on_inactive_channel(session, interaction): + if await fail_on_inactive_channel(session, interaction): return - pairing = get_pairing(session, interaction, interaction.user, user) + pairing = get_pairing(session, interaction.user, user, None) usernames = sorted([user.name for user in (interaction.user, user)]) @@ -635,6 +654,7 @@ async def _pair_with( ) pairing = Pairing( + thread_id=thread.id, channel_id=interaction.channel_id, user_1_id=interaction.user.id, user_2_id=user.id, @@ -651,6 +671,7 @@ async def _pair_with( ) pairing = Pairing( + thread_id=thread.id, channel_id=interaction.channel_id, user_1_id=interaction.user.id, user_2_id=user.id, @@ -666,6 +687,22 @@ async def _pair_with( ) +@app.command( + name="makegroups", + description="Assign random groups to people subscribed to pairing today." +) +@discord.app_commands.guild_only() +@discord.app_commands.checks.has_permissions(administrator=True) +async def _make_groups( + interaction: discord.Interaction[Any] +): + await make_groups() + await interaction.response.send_message( + f"Great success!! :-)", + ephemeral=True + ) + + @discord.ext.tasks.loop(time=config.PAIRING_TIME) async def make_groups(): app.logger.info("Creating random groups.") @@ -719,31 +756,54 @@ async def make_groups(): groups.append(pair) for group in groups: - usernames = sorted([user.name for user in group]) + usernames = sorted([user.name for user in group if user is not None]) user1, user2, user3 = group - # TODO find existing pairing and get thread - pairing = Pairing( - channel_id=channel.id, - user_1_id=user1.id, - user_2_id=user2.id, - user_3_id=user3.id if user3 is not None else None, + pairing = get_pairing(session, user1, user2, user3) + + if pairing is None: + # Make new discord thread + thread = await channel.create_thread( + name=" & ".join(usernames) + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=channel.id, + user_1_id=user1.id, + user_2_id=user2.id, + user_3_id=user3.id if user3 is not None else None, + ) + session.add(pairing) + else: + thread = guild.get_thread(pairing.thread_id) + if thread is None: + # Thread deleted somehow. Start over + session.delete(pairing) + + # Make new discord thread + thread = await channel.create_thread( + name=" & ".join(usernames) + ) + + pairing = Pairing( + thread_id=thread.id, + channel_id=channel.id, + user_1_id=user1.id, + user_2_id=user2.id, + user_3_id=user3.id if user3 is not None else None, + ) + session.add(pairing) + + msg = ", ".join(("<@%d>" % user.id) for user in group if user is not None) + msg += ": " + msg += "You have been matched together. Happy pairing! :computer:" + + await thread.send(msg) + + app.logger.info( + "Made pairing group with %s" % " & ".join(usernames) ) - session.add(pairing) - - # Make new discord thread - thread = await channel.create_thread( - name=" & ".join(usernames) - ) - - await thread.send( - " ".join("<@%d>" % (user.id for user in group if user is not None)) - ) - await thread.send( - "You have been matched together. Happy pairing! :computer:" - ) - - app.logger.info("Send out pairings.") def run(): diff --git a/pairbot/utils.py b/pairbot/utils.py deleted file mode 100644 index ab3ee51..0000000 --- a/pairbot/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import argparse -import json - -import discord -import discord.ext.commands - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--dev", action="store_true", help="use dev environment") - return parser.parse_args() - - -def read_guild_to_channel(path: str): - with open(path, "r") as f: - return json.load(f) - - -def get_user_name(user: discord.User): - if user.global_name is not None: - return user.global_name - elif user.nick is not None: - return user.nick - else: - return user.name