// Package tools provides HTTP middleware: proof-of-work challenge
// verification, per-client rate limiting, moderator gating, middleware
// chaining and a method-based dispatcher.
package tools

import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"math"
	"math/bits"
	"net/http"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
)

// ChallengeSessionData is an outstanding proof-of-work challenge, stored in
// ChallengeSession under its hex nonce.
type ChallengeSessionData struct {
	// Expires is compared against time.Now().Unix() (see challengeExpired).
	// NOTE(review): confirm the code issuing challenges stores Unix seconds.
	Expires int64
	// Difficulty is the number of leading zero bits the solution hash needs.
	Difficulty int
}

// RatelimitEntry counts one client's requests in the current window.
type RatelimitEntry struct {
	Reset int64 // end of the current window, Unix nanoseconds
	Usage int32 // requests counted in the current window (saturates past the limit)
}

// RatelimitShard is one mutex-guarded slice of the rate-limit table; keys are
// spread across RatelimitShardCount shards to reduce lock contention.
type RatelimitShard struct {
	sync.Mutex
	data map[string]RatelimitEntry
}

const (
	// RatelimitShardCount is the number of rate-limit shards.
	RatelimitShardCount = 256
)

var (
	// ChallengeAtomic guards ChallengeSession.
	ChallengeAtomic sync.Mutex
	// ChallengeSession holds issued, not-yet-consumed challenges keyed by nonce.
	ChallengeSession = make(map[string]ChallengeSessionData, 1024)
	// RatelimitShards holds the sharded rate-limit state.
	RatelimitShards = make([]RatelimitShard, RatelimitShardCount)
)

// challengeExpired reports whether a challenge whose Expires field is expires
// is no longer valid at now. Both the periodic sweep and the per-request check
// use this so they agree on the unit (Unix seconds) and the boundary.
func challengeExpired(expires int64, now time.Time) bool {
	return now.Unix() >= expires
}

func init() {
	for i := range RatelimitShards {
		RatelimitShards[i].data = make(map[string]RatelimitEntry, 128)
	}

	// Periodically drop expired challenges that were never consumed.
	go func() {
		t := time.NewTicker(10 * time.Minute)
		defer t.Stop()
		for range t.C {
			now := time.Now()
			ChallengeAtomic.Lock()
			for k, v := range ChallengeSession {
				if challengeExpired(v.Expires, now) {
					delete(ChallengeSession, k)
				}
			}
			ChallengeAtomic.Unlock()
		}
	}()

	// Periodically drop rate-limit entries whose window has ended.
	go func() {
		t := time.NewTicker(10 * time.Minute)
		defer t.Stop()
		for range t.C {
			now := time.Now().UnixNano()
			for i := range RatelimitShards {
				s := &RatelimitShards[i]
				s.Lock()
				for k, v := range s.data {
					if now >= v.Reset {
						delete(s.data, k)
					}
				}
				s.Unlock()
			}
		}
	}()
}

// NewChallenger returns middleware that requires a solved proof-of-work
// challenge. The client sends the hex nonce it was issued in
// X-Challenge-Nonce and a non-negative decimal counter in X-Challenge-Counter.
// The challenge passes when SHA-256(nonce + counter) begins with at least the
// session's Difficulty zero bits. Each nonce is single-use: it is removed from
// ChallengeSession before validation, so a wrong answer also burns it.
// Challenges whose Difficulty is below minimumDifficulty are rejected.
func NewChallenger(minimumDifficulty int) ChainMiddleware {
	return func(w http.ResponseWriter, r *http.Request) bool {
		givenNonce := r.Header.Get("X-Challenge-Nonce")
		if _, err := hex.DecodeString(givenNonce); err != nil {
			SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
			return false
		}
		givenCounter, err := strconv.Atoi(r.Header.Get("X-Challenge-Counter"))
		if err != nil || givenCounter < 0 {
			SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
			return false
		}

		// Consume the session before validating so each nonce gets one attempt.
		ChallengeAtomic.Lock()
		session, exists := ChallengeSession[givenNonce]
		delete(ChallengeSession, givenNonce)
		ChallengeAtomic.Unlock()
		if !exists {
			SendClientError(w, r, ERROR_UNKNOWN_CHALLENGE)
			return false
		}
		if challengeExpired(session.Expires, time.Now()) {
			SendClientError(w, r, ERROR_CHALLENGE_EXPIRED)
			return false
		}
		if session.Difficulty < minimumDifficulty {
			SendClientError(w, r, ERROR_CHALLENGE_TOO_EASY)
			return false
		}

		// Count leading zero bits of the hash of nonce || decimal(counter).
		sum := sha256.Sum256([]byte(fmt.Sprintf("%s%d", givenNonce, givenCounter)))
		zeroBits := 0
		for _, b := range sum {
			if b != 0 {
				zeroBits += bits.LeadingZeros8(b)
				break
			}
			zeroBits += 8
		}
		if zeroBits < session.Difficulty {
			SendClientError(w, r, ERROR_CHALLENGE_INVALID)
			return false
		}
		return true
	}
}

// NewRatelimiter returns middleware that allows at most limit requests per
// period for each (categoryName, client address) pair. It always sets the
// X-Ratelimit-* headers and rejects with ERROR_GENERIC_RATELIMIT once the
// count for the current window exceeds limit.
func NewRatelimiter(categoryName string, limit int32, period time.Duration) ChainMiddleware {
	limitStr := strconv.Itoa(int(limit))
	return func(w http.ResponseWriter, r *http.Request) bool {
		name := categoryName + ":" + RequestAddressHash(r)

		// FNV-1a 32-bit, used only to pick a shard.
		var h uint32 = 2166136261
		for i := 0; i < len(name); i++ {
			h ^= uint32(name[i])
			h *= 16777619
		}

		// Calculate Usage
		s := &RatelimitShards[h%RatelimitShardCount]
		s.Lock()
		now := time.Now()
		e, ok := s.data[name]
		switch {
		case !ok || now.UnixNano() >= e.Reset:
			e = RatelimitEntry{Reset: now.Add(period).UnixNano(), Usage: 1}
		case e.Usage <= limit && e.Usage < math.MaxInt32:
			// Stop counting once over the limit so a sustained flood cannot
			// overflow int32; the reject decision only needs Usage > limit.
			e.Usage++
		}
		s.data[name] = e
		s.Unlock()

		// Generate Headers
		resetSecs := strconv.FormatFloat(float64(e.Reset-now.UnixNano())/float64(time.Second), 'f', 2, 64)
		remaining := max(0, limit-e.Usage)
		hdr := w.Header()
		hdr.Set("X-Ratelimit-Category", categoryName)
		hdr.Set("X-Ratelimit-Reset", resetSecs)
		hdr.Set("X-Ratelimit-Limit", limitStr)
		hdr.Set("X-Ratelimit-Remaining", strconv.Itoa(int(remaining)))
		if e.Usage > limit {
			SendClientError(w, r, ERROR_GENERIC_RATELIMIT)
			return false
		}
		return true
	}
}

// NewGatekeeper returns middleware that admits only privileged callers. When
// allowModerators is true, the Authorization header is hashed with RequestHash
// and looked up in gifuu.mod_key; a matching row admits the request. When
// allowModerators is false, every request is rejected as unauthorized.
func NewGatekeeper(allowModerators bool) ChainMiddleware {
	return func(w http.ResponseWriter, r *http.Request) bool {
		if !allowModerators {
			// Always write a response: returning false alone stops the chain
			// with nothing written, and net/http would send an empty 200.
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		}

		var found int
		err := Database.QueryRow(r.Context(),
			"SELECT 1 FROM gifuu.mod_key WHERE token_hash = $1",
			RequestHash(r.Header.Get("Authorization")),
		).Scan(&found)
		switch {
		case errors.Is(err, pgx.ErrNoRows):
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		case err != nil:
			SendServerError(w, r, err)
			return false
		}
		return true
	}
}

// ChainMiddleware inspects a request before its handler runs. It returns true
// to continue the chain, or false after having written its own response.
type ChainMiddleware func(w http.ResponseWriter, r *http.Request) bool

// Chain wraps h so that each middleware in wares runs first, in order; the
// first one returning false stops processing and h is not called.
func Chain(h http.HandlerFunc, wares ...ChainMiddleware) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		for _, mw := range wares {
			if !mw(w, r) {
				return
			}
		}
		h(w, r)
	}
}

// MethodHandler dispatches a request to the handler registered for its HTTP
// method.
type MethodHandler map[string]http.HandlerFunc

// ServeHTTP calls the handler registered for r.Method. For an unregistered
// method it sets the Allow header to the registered methods (RFC 9110 requires
// it on 405 responses) and sends ERROR_GENERIC_METHOD_NOT_ALLOWED.
func (mh MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if handler, ok := mh[r.Method]; ok {
		handler(w, r)
		return
	}
	allowed := make([]string, 0, len(mh))
	for method := range mh {
		allowed = append(allowed, method)
	}
	slices.Sort(allowed)
	w.Header().Set("Allow", strings.Join(allowed, ", "))
	SendClientError(w, r, ERROR_GENERIC_METHOD_NOT_ALLOWED)
}