Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 35 additions & 36 deletions src/webhooks/actix_web.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,61 @@
use crate::Incoming;
use actix_web::{
dev::Payload,
error::{Error, ErrorBadRequest, ErrorUnauthorized},
web::Json,
FromRequest, HttpRequest,
};
use serde::de::DeserializeOwned;
use super::IncomingPayload;
use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll, ready},
};

use actix_web::{
FromRequest, HttpRequest,
dev::Payload,
error::{Error, ErrorBadRequest, ErrorUnauthorized},
};
use futures_core::stream::Stream;

#[doc(hidden)]
pub struct IncomingFut<T: DeserializeOwned> {
pub struct IncomingPayloadFut {
req: HttpRequest,
json_fut: <Json<T> as FromRequest>::Future,
payload: Payload,
body: Vec<u8>,
}

impl<T> Future for IncomingFut<T>
where
T: DeserializeOwned,
{
type Output = Result<Incoming<T>, Error>;
impl Future for IncomingPayloadFut {
type Output = Result<IncomingPayload, Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Ok(json) = ready!(Pin::new(&mut self.json_fut).poll(cx)) {
let headers = self.req.headers();

if let Some(authorization) = headers.get("Authorization") {
if let Ok(authorization) = authorization.to_str() {
return Poll::Ready(Ok(Incoming {
authorization: authorization.to_owned(),
data: json.into_inner(),
}));
}
while let Some(body) = ready!(Pin::new(&mut self.payload).poll_next(cx)) {
match body {
Ok(body) => self.body.extend_from_slice(&body),

Err(_) => return Poll::Ready(Err(ErrorBadRequest("400"))),
}
}

let headers = self.req.headers();

return Poll::Ready(Err(ErrorUnauthorized("401")));
if let (Some(signature), Some(trace)) = (
headers.get("x-topgg-signature"),
headers.get("x-topgg-trace"),
) && let (Ok(signature), Ok(trace)) = (signature.to_str(), trace.to_str())
&& let Some(incoming) = IncomingPayload::new(signature, self.body.clone(), trace)
{
return Poll::Ready(Ok(incoming));
}

Poll::Ready(Err(ErrorBadRequest("400")))
Poll::Ready(Err(ErrorUnauthorized("401")))
}
}

#[cfg_attr(docsrs, doc(cfg(feature = "actix-web")))]
impl<T> FromRequest for Incoming<T>
where
T: DeserializeOwned,
{
impl FromRequest for IncomingPayload {
type Error = Error;
type Future = IncomingFut<T>;
type Future = IncomingPayloadFut;

#[inline(always)]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
IncomingFut {
IncomingPayloadFut {
req: req.clone(),
json_fut: Json::from_request(req, payload),
payload: payload.take(),
body: vec![],
}
}
}
95 changes: 57 additions & 38 deletions src/webhooks/axum.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,69 @@
use super::Webhook;
use super::Payload;
use std::sync::Arc;

use axum::{
Router,
extract::State,
http::{HeaderMap, StatusCode},
response::IntoResponse,
response::{IntoResponse, Response},
routing::post,
Router,
};
use serde::de::DeserializeOwned;
use std::sync::Arc;

/// An axum webhook listener for listening to payloads.
///
/// # Example
///
/// ```rust,no_run
/// struct MyTopggListener {}
///
/// #[async_trait::async_trait]
/// impl topgg::axum::Listener for MyTopggListener {
/// async fn callback(self: Arc<Self>, payload: Payload, _trace: &str) -> Response {
/// println!("{payload:?}");
///
/// (StatusCode::NO_CONTENT, ()).into_response()
/// }
/// }
/// ```
#[async_trait::async_trait]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub trait Listener: Send + Sync + 'static {
async fn callback(self: Arc<Self>, payload: Payload, trace: &str) -> Response;
}

struct WebhookState<T> {
state: Arc<T>,
password: Arc<String>,
secret: Arc<String>,
}

impl<T> Clone for WebhookState<T> {
#[inline(always)]
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
password: Arc::clone(&self.password),
state: self.state.clone(),
secret: self.secret.clone(),
}
}
}

/// Creates a new axum [`Router`] for receiving vote events.
/// Creates a new axum [`Router`] for receiving webhook payloads.
///
/// # Example
///
/// ```rust,no_run
/// use axum::{routing::get, Router};
/// use topgg::{VoteEvent, Webhook};
/// use tokio::net::TcpListener;
/// use topgg::Payload;
/// use std::sync::Arc;
///
/// struct MyVoteListener {}
/// use axum::{http::status::StatusCode, response::{IntoResponse, Response}, routing::get, Router};
/// use tokio::net::TcpListener;
///
/// struct MyTopggListener {}
///
/// #[async_trait::async_trait]
/// impl Webhook<VoteEvent> for MyVoteListener {
/// async fn callback(&self, vote: VoteEvent) {
/// println!("A user with the ID of {} has voted us on Top.gg!", vote.voter_id);
/// impl topgg::axum::Listener for MyTopggListener {
/// async fn callback(self: Arc<Self>, payload: Payload, _trace: &str) -> Response {
/// println!("{payload:?}");
///
/// (StatusCode::NO_CONTENT, ()).into_response()
/// }
/// }
///
Expand All @@ -49,48 +73,43 @@ impl<T> Clone for WebhookState<T> {
///
/// #[tokio::main]
/// async fn main() {
/// let state = Arc::new(MyVoteListener {});
/// let state = Arc::new(MyTopggListener {});
///
/// let router = Router::new().route("/", get(index)).nest(
/// "/votes",
/// topgg::axum::webhook(env!("MY_TOPGG_WEBHOOK_SECRET").to_string(), Arc::clone(&state)),
/// "/webhook",
/// topgg::axum::webhook(Arc::clone(&state), env!("TOPGG_WEBHOOK_SECRET").to_string()),
/// );
///
/// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
///
/// axum::serve(listener, router).await.unwrap();
/// }
/// ```
#[inline(always)]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub fn webhook<D, T>(password: String, state: Arc<T>) -> Router
pub fn webhook<S>(state: Arc<S>, secret: String) -> Router
where
D: DeserializeOwned + Send,
T: Webhook<D>,
S: Listener,
{
Router::new()
.route(
"/",
post(
async |headers: HeaderMap, State(webhook): State<WebhookState<T>>, body: String| {
if let Some(authorization) = headers.get("Authorization") {
if let Ok(authorization) = authorization.to_str() {
if authorization == *(webhook.password) {
if let Ok(data) = serde_json::from_str(&body) {
webhook.state.callback(data).await;

return (StatusCode::NO_CONTENT, ()).into_response();
}
}
}
async |headers: HeaderMap, State(wrapped_state): State<WebhookState<S>>, body: String| {
if let Some(signature) = headers.get("x-topgg-signature")
&& let Ok(signature) = signature.to_str()
&& let Some(trace) = headers.get("x-topgg-trace")
&& let Ok(trace) = trace.to_str()
&& let Some(payload) = Payload::new(signature, &body, &wrapped_state.secret)
{
wrapped_state.state.callback(payload, trace).await
} else {
(StatusCode::UNAUTHORIZED, ()).into_response()
}

(StatusCode::UNAUTHORIZED, ()).into_response()
},
),
)
.with_state(WebhookState {
state,
password: Arc::new(password),
secret: Arc::new(secret),
})
}
Loading