Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .github/workflows/docker-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,42 @@ on:
branches:
- main
types: [closed]
workflow_dispatch:

jobs:
build-and-push:
runs-on: ubuntu-latest
environment: DOCKER
if: github.event_name != 'pull_request' || github.event.pull_request.merged == true

steps:
- name: Check out the repo
uses: actions/checkout@v4

- name: Set up QEMU
uses: docker/setup-qemu-action@v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:buildx-stable-1

- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: asternic/wuzapi
images: devlucasmoraes/wuzapi
tags: |
type=raw,value=latest
type=sha,format=short

- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
Expand All @@ -52,4 +53,4 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
cache-to: type=gha,mode=max
45 changes: 27 additions & 18 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,16 @@ func (s *server) authalice(next http.Handler) http.Handler {
if !found {
log.Info().Msg("Looking for user information in DB")
// Checks DB from matching user and store user values in context
rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0 FROM users WHERE token=$1 LIMIT 1", token)
rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0,CASE WHEN s3_enabled THEN 'true' ELSE 'false' END,COALESCE(media_delivery, 'base64') FROM users WHERE token=$1 LIMIT 1", token)
if err != nil {
s.Respond(w, r, http.StatusInternalServerError, err)
return
}
defer rows.Close()
var history sql.NullInt64
var s3Enabled, mediaDelivery string
for rows.Next() {
err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac)
err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac, &s3Enabled, &mediaDelivery)
if err != nil {
s.Respond(w, r, http.StatusInternalServerError, err)
return
Expand All @@ -176,16 +177,18 @@ func (s *server) authalice(next http.Handler) http.Handler {
log.Debug().Str("userId", txtid).Bool("historyValid", history.Valid).Int64("historyValue", history.Int64).Str("historyStr", historyStr).Msg("User authentication - history debug")

v := Values{map[string]string{
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"S3Enabled": s3Enabled,
"MediaDelivery": mediaDelivery,
}}

userinfocache.Set(token, v, cache.NoExpiration)
Expand Down Expand Up @@ -5260,7 +5263,15 @@ func (s *server) DeleteUserComplete() http.HandlerFunc {
client.Disconnect()
}

// 2. Remove from DB
// 2. Query S3 config before deleting the user
var s3Enabled bool
err = s.db.QueryRow("SELECT s3_enabled FROM users WHERE id = $1", id).Scan(&s3Enabled)
if err != nil {
log.Error().Err(err).Str("id", id).Msg("problem retrieving user s3 configuration")
// Continue anyway since we have the ID to delete local files
}

// 3. Remove from DB
_, err = s.db.Exec("DELETE FROM users WHERE id = $1", id)
if err != nil {
s.respondWithJSON(w, http.StatusInternalServerError, map[string]interface{}{
Expand All @@ -5272,13 +5283,13 @@ func (s *server) DeleteUserComplete() http.HandlerFunc {
return
}

// 3. Cleanup from memory
// 4. Cleanup from memory
clientManager.DeleteWhatsmeowClient(id)
clientManager.DeleteMyClient(id)
clientManager.DeleteHTTPClient(id)
userinfocache.Delete(token)

// 4. Remove media files
// 5. Remove media files
userDirectory := filepath.Join(s.exPath, "files", id)
if stat, err := os.Stat(userDirectory); err == nil && stat.IsDir() {
log.Info().Str("dir", userDirectory).Msg("deleting media and history files from disk")
Expand All @@ -5288,9 +5299,7 @@ func (s *server) DeleteUserComplete() http.HandlerFunc {
}
}

// 5. Remove files from S3 (if enabled)
var s3Enabled bool
err = s.db.QueryRow("SELECT s3_enabled FROM users WHERE id = $1", id).Scan(&s3Enabled)
// 6. Remove files from S3 (if enabled)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a subtle bug in the if condition on the next line (if err == nil && s3Enabled). The err variable it checks is from the os.RemoveAll operation on line 5296. If local file deletion fails, this will prevent S3 objects from being deleted.

To fix this, the condition on line 5303 should be changed to just if s3Enabled. The s3Enabled flag will correctly be false if the initial database query for it failed, so checking err here is both incorrect and unnecessary.

if err == nil && s3Enabled {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
Expand Down
1 change: 1 addition & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ func ProcessOutgoingMedia(userID string, contactJID string, messageID string, da

// Process S3 upload if enabled
if s3Config.Enabled && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(userID)
// Process S3 upload (outgoing messages are always in outbox)
s3Data, err := GetS3Manager().ProcessMediaForS3(
context.Background(),
Expand Down
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ func main() {
}
}()

// Set DB reference in S3Manager for lazy client initialization
GetS3Manager().SetDB(db)

// Initialize the schema
if err = initializeSchema(db); err != nil {
log.Fatal().Err(err).Msg("Failed to initialize schema")
Expand Down
60 changes: 59 additions & 1 deletion s3manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/jmoiron/sqlx"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
Expand All @@ -32,6 +33,7 @@ type S3Config struct {
// S3Manager manages S3 operations
type S3Manager struct {
mu sync.RWMutex
db *sqlx.DB
clients map[string]*s3.Client
configs map[string]*S3Config
}
Expand All @@ -47,6 +49,56 @@ func GetS3Manager() *S3Manager {
return s3Manager
}

// SetDB sets the database reference for lazy S3 client initialization
func (m *S3Manager) SetDB(db *sqlx.DB) {
m.mu.Lock()
defer m.mu.Unlock()
m.db = db
}

// EnsureClientFromDB loads S3 config from DB and initializes client if enabled. Returns true if client is available.
func (m *S3Manager) EnsureClientFromDB(userID string) bool {
if _, _, ok := m.GetClient(userID); ok {
return true
}
m.mu.RLock()
db := m.db
m.mu.RUnlock()
if db == nil {
return false
}
var s3DbConfig struct {
Enabled bool `db:"s3_enabled"`
Endpoint string `db:"s3_endpoint"`
Region string `db:"s3_region"`
Bucket string `db:"s3_bucket"`
AccessKey string `db:"s3_access_key"`
SecretKey string `db:"s3_secret_key"`
PathStyle bool `db:"s3_path_style"`
PublicURL string `db:"s3_public_url"`
MediaDelivery string `db:"media_delivery"`
RetentionDays int `db:"s3_retention_days"`
}
query := `SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket, s3_access_key, s3_secret_key, s3_path_style, s3_public_url, COALESCE(media_delivery, 'base64') AS media_delivery, COALESCE(s3_retention_days, 30) AS s3_retention_days FROM users WHERE id = $1`
query = db.Rebind(query)
if err := db.Get(&s3DbConfig, query, userID); err != nil || !s3DbConfig.Enabled {
return false
}
config := &S3Config{
Enabled: s3DbConfig.Enabled,
Endpoint: s3DbConfig.Endpoint,
Region: s3DbConfig.Region,
Bucket: s3DbConfig.Bucket,
AccessKey: s3DbConfig.AccessKey,
SecretKey: s3DbConfig.SecretKey,
PathStyle: s3DbConfig.PathStyle,
PublicURL: s3DbConfig.PublicURL,
MediaDelivery: s3DbConfig.MediaDelivery,
RetentionDays: s3DbConfig.RetentionDays,
}
return m.InitializeS3Client(userID, config) == nil
}

// InitializeS3Client creates or updates S3 client for a user
func (m *S3Manager) InitializeS3Client(userID string, config *S3Config) error {
if !config.Enabled {
Expand Down Expand Up @@ -192,7 +244,13 @@ func (m *S3Manager) GenerateS3Key(userID, contactJID, messageID string, mimeType
func (m *S3Manager) UploadToS3(ctx context.Context, userID string, key string, data []byte, mimeType string) error {
client, config, ok := m.GetClient(userID)
if !ok {
return fmt.Errorf("S3 client not initialized for user %s", userID)
// Try lazy init from DB if available (handles reconnect-after-restart)
if m.EnsureClientFromDB(userID) {
client, config, ok = m.GetClient(userID)
}
if !ok {
return fmt.Errorf("S3 client not initialized for user %s", userID)
}
}

// Set content type and cache headers for preview
Expand Down
62 changes: 19 additions & 43 deletions wmiau.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type MyClient struct {
s *server
}

// ensureS3ClientForUser loads S3 config from DB and initializes client if not already present (lazy init for reconnect-after-restart)
func ensureS3ClientForUser(userID string) {
GetS3Manager().EnsureClientFromDB(userID)
}

func sendToGlobalWebHook(jsonData []byte, token string, userID string) {
jsonDataStr := string(jsonData)

Expand Down Expand Up @@ -295,49 +300,7 @@ func (s *server) connectOnStartup() {

// Initialize S3 client if configured
go func(userID string) {
var s3Config struct {
Enabled bool `db:"s3_enabled"`
Endpoint string `db:"s3_endpoint"`
Region string `db:"s3_region"`
Bucket string `db:"s3_bucket"`
AccessKey string `db:"s3_access_key"`
SecretKey string `db:"s3_secret_key"`
PathStyle bool `db:"s3_path_style"`
PublicURL string `db:"s3_public_url"`
RetentionDays int `db:"s3_retention_days"`
}

err := s.db.Get(&s3Config, `
SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket,
s3_access_key, s3_secret_key, s3_path_style,
s3_public_url, s3_retention_days
FROM users WHERE id = $1`, userID)

if err != nil {
log.Error().Err(err).Str("userID", userID).Msg("Failed to get S3 config")
return
}

if s3Config.Enabled {
config := &S3Config{
Enabled: s3Config.Enabled,
Endpoint: s3Config.Endpoint,
Region: s3Config.Region,
Bucket: s3Config.Bucket,
AccessKey: s3Config.AccessKey,
SecretKey: s3Config.SecretKey,
PathStyle: s3Config.PathStyle,
PublicURL: s3Config.PublicURL,
RetentionDays: s3Config.RetentionDays,
}

err = GetS3Manager().InitializeS3Client(userID, config)
if err != nil {
log.Error().Err(err).Str("userID", userID).Msg("Failed to initialize S3 client on startup")
} else {
log.Info().Str("userID", userID).Msg("S3 client initialized on startup")
}
}
GetS3Manager().EnsureClientFromDB(userID)
}(txtid)
}
}
Expand Down Expand Up @@ -461,6 +424,9 @@ func (s *server) startClient(userID string, textjid string, token string, subscr
}
clientManager.SetHTTPClient(userID, httpClient)

// Initialize S3 client if configured (needed when user reconnects after container restart - connectOnStartup only runs for connected=1)
GetS3Manager().EnsureClientFromDB(userID)

if client.Store.ID == nil {
// No ID stored, new login
qrChan, err := client.GetQRChannel(context.Background())
Expand Down Expand Up @@ -817,6 +783,11 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {
s3Config.MediaDelivery = myuserinfo.(Values).Get("MediaDelivery")
}

// Lazy init S3 client if needed (handles reconnect-after-restart when connectOnStartup skipped this user)
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
}

postmap["type"] = "Message"
dowebhook = 1
metaParts := []string{fmt.Sprintf("pushname: %s", evt.Info.PushName), fmt.Sprintf("timestamp: %s", evt.Info.Timestamp)}
Expand Down Expand Up @@ -867,6 +838,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This call to ensureS3ClientForUser(txtid) is redundant because an identical check and call is already performed on line 787 for all message events. This redundant call is repeated for all media types (audio, document, video, sticker).

Please remove this line and the similar ones for other media types to improve clarity and avoid unnecessary function calls.

// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -955,6 +927,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1048,6 +1021,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1130,6 +1104,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1212,6 +1187,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// if using S3 (same stream as other media)
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
if evt.Info.IsGroup {
Expand Down
Loading