Files
2026-05-23 17:17:56 -07:00

238 lines
5.2 KiB
Go

package tools
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"math/bits"
"net/http"
"strconv"
"sync"
"time"
"github.com/jackc/pgx/v5"
)
type (
	// ChallengeSessionData is the state kept for one issued proof-of-work
	// challenge, keyed by its nonce in ChallengeSession.
	//
	// Expires is compared against the clock by NewChallenger and by the
	// cleanup goroutine started in init. NOTE(review): those two readers use
	// different units (seconds vs nanoseconds) — confirm which one the
	// issuer writes.
	// Difficulty is the number of leading zero bits the solution hash must
	// have.
	ChallengeSessionData struct {
		Expires    int64
		Difficulty int
	}

	// RatelimitEntry tracks one rate-limit key within its current window.
	// Reset is the UnixNano timestamp at which the window ends; Usage counts
	// the requests seen in that window.
	RatelimitEntry struct {
		Reset int64
		Usage int32
	}

	// RatelimitShard is one independently locked partition of the rate-limit
	// table. The embedded mutex guards data, which maps a rate-limit key to its
	// entry.
	RatelimitShard struct {
		sync.Mutex
		data map[string]RatelimitEntry
	}
)

// RatelimitShardCount is the number of partitions in RatelimitShards;
// NewRatelimiter picks one by hashing the request key modulo this value.
const RatelimitShardCount = 256

var (
	// ChallengeAtomic guards every read and write of ChallengeSession.
	ChallengeAtomic sync.Mutex

	// ChallengeSession holds the outstanding challenges, keyed by nonce.
	ChallengeSession = make(map[string]ChallengeSessionData, 1024)

	// RatelimitShards holds RatelimitShardCount shards; each shard's map is
	// allocated in init.
	RatelimitShards = make([]RatelimitShard, RatelimitShardCount)
)
// init allocates the map of every rate-limit shard and starts two background
// sweepers. Each sweeper ticks every 10 minutes for the life of the process;
// nothing ever stops them.
//
//   - The challenge sweeper deletes sessions whose Expires is strictly less
//     than the current time in nanoseconds.
//   - The rate-limit sweeper deletes entries whose Reset is at or before the
//     current time in nanoseconds, locking one shard at a time.
//
// NOTE(review): NewChallenger checks Expires against time.Now().Unix()
// (seconds), while the challenge sweeper below compares it with UnixNano().
// Whichever unit the challenge issuer writes, one of the two checks is wrong.
// Confirm against the issuer and align both.
func init() {
	for i := range RatelimitShards {
		RatelimitShards[i].data = make(map[string]RatelimitEntry, 128)
	}

	sweepChallenges := func(cutoff int64) {
		ChallengeAtomic.Lock()
		for nonce, session := range ChallengeSession {
			if session.Expires < cutoff {
				delete(ChallengeSession, nonce)
			}
		}
		ChallengeAtomic.Unlock()
	}

	sweepRatelimits := func(cutoff int64) {
		for i := range RatelimitShards {
			shard := &RatelimitShards[i]
			shard.Lock()
			for key, entry := range shard.data {
				if entry.Reset <= cutoff {
					delete(shard.data, key)
				}
			}
			shard.Unlock()
		}
	}

	// every runs sweep in its own goroutine on each tick of a ticker with the
	// given interval, passing the current time in nanoseconds.
	every := func(interval time.Duration, sweep func(int64)) {
		go func() {
			ticker := time.NewTicker(interval)
			defer ticker.Stop()
			for range ticker.C {
				sweep(time.Now().UnixNano())
			}
		}()
	}

	every(10*time.Minute, sweepChallenges)
	every(10*time.Minute, sweepRatelimits)
}
// NewChallenger returns middleware that enforces a proof-of-work challenge.
//
// The client sends the challenge nonce in X-Challenge-Nonce, which must be
// valid hex. It sends its solution in X-Challenge-Counter, which must be a
// non-negative integer; otherwise ERROR_BODY_INVALID_CHALLENGE is returned.
// The nonce must be present in ChallengeSession (else
// ERROR_UNKNOWN_CHALLENGE). The session is removed on first use, before any
// further check, so each challenge can be attempted only once.
//
// The request then passes only if all of these hold:
//   - the session has not expired (else ERROR_CHALLENGE_EXPIRED);
//   - its Difficulty is at least minimumDifficulty (else
//     ERROR_CHALLENGE_TOO_EASY);
//   - SHA-256 of the nonce followed by the decimal counter starts with at
//     least Difficulty zero bits (else ERROR_CHALLENGE_INVALID).
//
// NOTE(review): expiry is checked here with time.Now().Unix() (seconds),
// whereas the cleanup goroutine in init compares Expires with UnixNano().
// Confirm which unit the challenge issuer stores and make both agree.
func NewChallenger(minimumDifficulty int) ChainMiddleware {
	return func(w http.ResponseWriter, r *http.Request) bool {
		nonce := r.Header.Get("X-Challenge-Nonce")
		if _, err := hex.DecodeString(nonce); err != nil {
			SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
			return false
		}
		counter, err := strconv.Atoi(r.Header.Get("X-Challenge-Counter"))
		if err != nil || counter < 0 {
			SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
			return false
		}

		// Look up and delete under a single lock so two concurrent requests
		// cannot both spend the same nonce.
		claim := func() (ChallengeSessionData, bool) {
			ChallengeAtomic.Lock()
			defer ChallengeAtomic.Unlock()
			s, found := ChallengeSession[nonce]
			if found {
				delete(ChallengeSession, nonce)
			}
			return s, found
		}
		session, found := claim()
		if !found {
			SendClientError(w, r, ERROR_UNKNOWN_CHALLENGE)
			return false
		}

		switch {
		case time.Now().Unix() >= session.Expires:
			SendClientError(w, r, ERROR_CHALLENGE_EXPIRED)
			return false
		case session.Difficulty < minimumDifficulty:
			SendClientError(w, r, ERROR_CHALLENGE_TOO_EASY)
			return false
		}

		// Count leading zero bits of the digest: whole zero bytes contribute
		// 8 each, and the first non-zero byte contributes its own leading
		// zeros and ends the scan.
		digest := sha256.Sum256([]byte(fmt.Sprintf("%s%d", nonce, counter)))
		leading := 0
		for _, b := range digest {
			leading += bits.LeadingZeros8(b)
			if b != 0 {
				break
			}
		}
		if leading < session.Difficulty {
			SendClientError(w, r, ERROR_CHALLENGE_INVALID)
			return false
		}
		return true
	}
}
// NewRatelimiter returns middleware that admits at most limit requests per
// period for each (categoryName, client address) pair.
//
// The key is categoryName + ":" + RequestAddressHash(r). A window starts on
// the first request for a key, or on the first request after the previous
// window's Reset, and lasts period.
//
// Every response gets four headers:
//   - X-Ratelimit-Category: the category name;
//   - X-Ratelimit-Reset: seconds until the window ends, two decimals;
//   - X-Ratelimit-Limit: the configured limit;
//   - X-Ratelimit-Remaining: requests left in the window, never negative.
//
// Requests beyond the limit are answered with ERROR_GENERIC_RATELIMIT and
// stop the chain.
func NewRatelimiter(categoryName string, limit int32, period time.Duration) ChainMiddleware {
	limitStr := strconv.Itoa(int(limit))
	return func(w http.ResponseWriter, r *http.Request) bool {
		name := categoryName + ":" + RequestAddressHash(r)

		// FNV-1a 32-bit over the key; used only to pick a shard.
		var h uint32 = 2166136261
		for i := 0; i < len(name); i++ {
			h ^= uint32(name[i])
			h *= 16777619
		}
		s := &RatelimitShards[h%RatelimitShardCount]

		// Update this key's usage under the shard lock.
		s.Lock()
		now := time.Now()
		e, ok := s.data[name]
		switch {
		case !ok || now.UnixNano() >= e.Reset:
			// Unknown or expired key: open a fresh window.
			e = RatelimitEntry{Reset: now.Add(period).UnixNano(), Usage: 1}
		case e.Usage <= limit:
			e.Usage++
		default:
			// Already over the limit. Saturate at limit+1 instead of counting
			// on, so a client flooding a long window cannot overflow the
			// int32 counter into negative values and slip back under the
			// limit.
		}
		s.data[name] = e
		s.Unlock()

		// Generate headers.
		resetSecs := strconv.FormatFloat(float64(e.Reset-now.UnixNano())/float64(time.Second), 'f', 2, 64)
		remaining := max(0, limit-e.Usage)
		hdr := w.Header()
		hdr.Set("X-Ratelimit-Category", categoryName)
		hdr.Set("X-Ratelimit-Reset", resetSecs)
		hdr.Set("X-Ratelimit-Limit", limitStr)
		hdr.Set("X-Ratelimit-Remaining", strconv.Itoa(int(remaining)))

		if e.Usage > limit {
			SendClientError(w, r, ERROR_GENERIC_RATELIMIT)
			return false
		}
		return true
	}
}
// NewGatekeeper returns middleware that admits only moderators.
//
// With allowModerators set, the Authorization header value is hashed with
// RequestHash and looked up as token_hash in gifuu.mod_key:
//   - a matching row lets the request continue;
//   - no row, or an empty header, is answered with ERROR_GENERIC_UNAUTHORIZED;
//   - any other database error is passed to SendServerError.
//
// With allowModerators unset, nobody is admitted and every request is
// answered with ERROR_GENERIC_UNAUTHORIZED.
func NewGatekeeper(allowModerators bool) ChainMiddleware {
	return func(w http.ResponseWriter, r *http.Request) bool {
		// Every path that returns false must write a response itself. If
		// nothing is written, net/http completes the request as an empty
		// 200 OK.
		if !allowModerators {
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		}

		token := r.Header.Get("Authorization")
		if token == "" {
			// No credentials supplied: reject without a database round trip.
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		}

		var found int
		err := Database.QueryRow(r.Context(),
			"SELECT 1 FROM gifuu.mod_key WHERE token_hash = $1",
			RequestHash(token),
		).Scan(&found)
		switch {
		case err == pgx.ErrNoRows:
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		case err != nil:
			SendServerError(w, r, err)
			return false
		}
		return true
	}
}
// ChainMiddleware is a step run before a handler. It returns true to let the
// request continue down the chain. It returns false to stop the chain, in
// which case it is responsible for any response.
type ChainMiddleware func(w http.ResponseWriter, r *http.Request) bool

// Chain wraps h so that each middleware in wares runs first, in the given
// order. The first middleware that returns false ends the request; later
// middlewares and h are not called. With no middlewares, h runs directly.
func Chain(h http.HandlerFunc, wares ...ChainMiddleware) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		for i := 0; i < len(wares); i++ {
			if proceed := wares[i](w, r); !proceed {
				return
			}
		}
		h(w, r)
	}
}
// MethodHandler routes a request to the handler registered under its HTTP
// method (the map key, matched exactly against r.Method).
type MethodHandler map[string]http.HandlerFunc

// ServeHTTP implements http.Handler. It calls the handler registered for
// r.Method. If none is registered, it lists every registered method in Allow
// header lines and answers with ERROR_GENERIC_METHOD_NOT_ALLOWED via
// SendClientError.
func (mh MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if handler, ok := mh[r.Method]; ok {
		handler(w, r)
		return
	}

	// RFC 9110 requires a 405 response to carry an Allow header listing the
	// supported methods. Allow is a list-based field, so one field line per
	// method is equivalent to a comma-joined value. Map iteration order is
	// random, so the order of the values varies; it carries no meaning.
	// NOTE(review): assumes ERROR_GENERIC_METHOD_NOT_ALLOWED is sent as a 405
	// — confirm.
	hdr := w.Header()
	for method := range mh {
		hdr.Add("Allow", method)
	}
	SendClientError(w, r, ERROR_GENERIC_METHOD_NOT_ALLOWED)
}