238 lines
5.2 KiB
Go
238 lines
5.2 KiB
Go
|
|
package tools
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/sha256"
|
||
|
|
"encoding/hex"
|
||
|
|
"fmt"
|
||
|
|
"math/bits"
|
||
|
|
"net/http"
|
||
|
|
"strconv"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/jackc/pgx/v5"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ChallengeSessionData is the server-side record of one issued proof-of-work
// challenge, looked up by its nonce in ChallengeSession.
type ChallengeSessionData struct {
	// Expires is the expiry timestamp; NewChallenger rejects the session once
	// time.Now().Unix() >= Expires, i.e. it treats the value as Unix seconds.
	Expires int64
	// Difficulty is the number of leading zero bits the solution hash must have.
	Difficulty int
}
|
||
|
|
|
||
|
|
// RatelimitEntry is the fixed-window counter kept per rate-limit key.
type RatelimitEntry struct {
	// Reset is the instant (Unix nanoseconds) at which the current window ends.
	Reset int64
	// Usage is the number of requests counted in the current window.
	Usage int32
}
|
||
|
|
|
||
|
|
type RatelimitShard struct {
|
||
|
|
sync.Mutex
|
||
|
|
data map[string]RatelimitEntry
|
||
|
|
}
|
||
|
|
|
||
|
|
// RatelimitShardCount is the number of independently locked rate-limit shards.
// It is left untyped on purpose so it can serve directly as the modulus for the
// unsigned hash that selects a shard.
const RatelimitShardCount = 256
|
||
|
|
|
||
|
|
var (
|
||
|
|
ChallengeAtomic sync.Mutex
|
||
|
|
ChallengeSession = make(map[string]ChallengeSessionData, 1024)
|
||
|
|
RatelimitShards = make([]RatelimitShard, RatelimitShardCount)
|
||
|
|
)
|
||
|
|
|
||
|
|
// init allocates every rate-limit shard's map and starts two background
// janitors. Every 10 minutes they remove expired challenge sessions and
// expired rate-limit entries, so neither store grows without bound.
func init() {
	for i := range RatelimitShards {
		RatelimitShards[i].data = make(map[string]RatelimitEntry, 128)
	}

	// Cleanup Challenges
	go func() {
		t := time.NewTicker(10 * time.Minute)
		defer t.Stop()
		for range t.C {
			// NewChallenger rejects a session once time.Now().Unix() >= Expires,
			// so the comparison here must use the same unit (seconds) and the
			// same boundary. Comparing against UnixNano() would make every
			// seconds-based expiry look long past and wipe live challenges on
			// every tick.
			// NOTE(review): assumes the challenge issuer (not in this file) also
			// writes Expires in Unix seconds — confirm.
			now := time.Now().Unix()
			ChallengeAtomic.Lock()
			for k, v := range ChallengeSession {
				if now >= v.Expires {
					delete(ChallengeSession, k)
				}
			}
			ChallengeAtomic.Unlock()
		}
	}()

	// Cleanup Ratelimits
	go func() {
		t := time.NewTicker(10 * time.Minute)
		defer t.Stop()
		for range t.C {
			// Rate-limit entries store Reset in Unix nanoseconds (see NewRatelimiter).
			now := time.Now().UnixNano()
			for i := range RatelimitShards {
				s := &RatelimitShards[i]
				s.Lock()
				for k, v := range s.data {
					if now >= v.Reset {
						delete(s.data, k)
					}
				}
				s.Unlock()
			}
		}
	}()
}
|
||
|
|
|
||
|
|
func NewChallenger(minimumDifficulty int) ChainMiddleware {
|
||
|
|
return func(w http.ResponseWriter, r *http.Request) bool {
|
||
|
|
|
||
|
|
// --- Proof of Work ---
|
||
|
|
var givenNonce = r.Header.Get("X-Challenge-Nonce")
|
||
|
|
var givenCounterRaw = r.Header.Get("X-Challenge-Counter")
|
||
|
|
var givenCounter = 0
|
||
|
|
|
||
|
|
if _, err := hex.DecodeString(givenNonce); err != nil {
|
||
|
|
SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
if v, err := strconv.Atoi(givenCounterRaw); err != nil || v < 0 {
|
||
|
|
SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
|
||
|
|
return false
|
||
|
|
} else {
|
||
|
|
givenCounter = v
|
||
|
|
}
|
||
|
|
|
||
|
|
// Consume Session
|
||
|
|
ChallengeAtomic.Lock()
|
||
|
|
session, exists := ChallengeSession[givenNonce]
|
||
|
|
if !exists {
|
||
|
|
ChallengeAtomic.Unlock()
|
||
|
|
SendClientError(w, r, ERROR_UNKNOWN_CHALLENGE)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
delete(ChallengeSession, givenNonce)
|
||
|
|
ChallengeAtomic.Unlock()
|
||
|
|
|
||
|
|
if time.Now().Unix() >= session.Expires {
|
||
|
|
SendClientError(w, r, ERROR_CHALLENGE_EXPIRED)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
if session.Difficulty < minimumDifficulty {
|
||
|
|
SendClientError(w, r, ERROR_CHALLENGE_TOO_EASY)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate Results
|
||
|
|
sessionInput := fmt.Sprintf("%s%d", givenNonce, givenCounter)
|
||
|
|
sessionHash := sha256.Sum256([]byte(sessionInput))
|
||
|
|
zeroBitsRequired := session.Difficulty
|
||
|
|
zeroBitsFound := 0
|
||
|
|
|
||
|
|
for _, b := range sessionHash {
|
||
|
|
if b == 0 {
|
||
|
|
zeroBitsFound += 8
|
||
|
|
} else {
|
||
|
|
zeroBitsFound += bits.LeadingZeros8(b)
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if zeroBitsFound < zeroBitsRequired {
|
||
|
|
SendClientError(w, r, ERROR_CHALLENGE_INVALID)
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewRatelimiter returns middleware that allows each client at most limit
// requests per fixed window of length period within categoryName.
//
// Clients are keyed by categoryName + ":" + RequestAddressHash(r). A window
// starts with a client's first request, or with the first request after the
// previous window ended. Every response gets these headers:
//   - X-Ratelimit-Category: categoryName
//   - X-Ratelimit-Reset: seconds until the window resets, two decimals
//   - X-Ratelimit-Limit: limit
//   - X-Ratelimit-Remaining: requests left in the window, never negative
//
// Requests over the limit are answered with ERROR_GENERIC_RATELIMIT and stop
// the chain.
func NewRatelimiter(categoryName string, limit int32, period time.Duration) ChainMiddleware {
	// Highest value a usage counter may hold (math.MaxInt32). Counting
	// saturates here instead of wrapping to a negative number, which would make
	// a client that keeps flooding look like it is under the limit again.
	const usageCeiling int32 = 1<<31 - 1

	limitStr := strconv.Itoa(int(limit))

	return func(w http.ResponseWriter, r *http.Request) bool {
		name := categoryName + ":" + RequestAddressHash(r)

		// FNV-1a 32-bit over the key, used only to pick a shard.
		var h uint32 = 2166136261
		for i := 0; i < len(name); i++ {
			h ^= uint32(name[i])
			h *= 16777619
		}

		// Calculate Usage
		s := &RatelimitShards[h%RatelimitShardCount]
		s.Lock()
		now := time.Now()
		e, ok := s.data[name]
		switch {
		case !ok || now.UnixNano() >= e.Reset:
			e = RatelimitEntry{Reset: now.Add(period).UnixNano(), Usage: 1}
		case e.Usage < usageCeiling:
			e.Usage++
		}
		s.data[name] = e
		s.Unlock()

		// Generate Headers
		resetSecs := strconv.FormatFloat(float64(e.Reset-now.UnixNano())/float64(time.Second), 'f', 2, 64)
		// Computed in int64 so unusual limits (e.g. negative) cannot overflow.
		remaining := max(0, int64(limit)-int64(e.Usage))

		hdr := w.Header()
		hdr.Set("X-Ratelimit-Category", categoryName)
		hdr.Set("X-Ratelimit-Reset", resetSecs)
		hdr.Set("X-Ratelimit-Limit", limitStr)
		hdr.Set("X-Ratelimit-Remaining", strconv.FormatInt(remaining, 10))

		if e.Usage > limit {
			SendClientError(w, r, ERROR_GENERIC_RATELIMIT)
			return false
		}
		return true
	}
}
|
||
|
|
|
||
|
|
// NewGatekeeper returns middleware that admits only privileged callers.
//
// With allowModerators set, the Authorization header is hashed with
// RequestHash and looked up in gifuu.mod_key. The outcomes are:
//   - a matching row lets the request through;
//   - no match, or an empty header, is rejected with ERROR_GENERIC_UNAUTHORIZED;
//   - any other database error is reported with SendServerError.
//
// With allowModerators unset, nobody is admitted and every request is rejected
// with ERROR_GENERIC_UNAUTHORIZED.
func NewGatekeeper(allowModerators bool) ChainMiddleware {
	return func(w http.ResponseWriter, r *http.Request) bool {
		if !allowModerators {
			// No caller class is allowed. Still answer explicitly: returning
			// false stops the chain, so without this the client would receive
			// no error response at all.
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		}

		token := r.Header.Get("Authorization")
		if token == "" {
			// Nothing to look up; skip the database round trip.
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		}

		var matched int
		err := Database.QueryRow(r.Context(),
			"SELECT 1 FROM gifuu.mod_key WHERE token_hash = $1",
			RequestHash(token),
		).Scan(&matched)
		switch {
		case err == pgx.ErrNoRows:
			SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
			return false
		case err != nil:
			SendServerError(w, r, err)
			return false
		}
		return true
	}
}
|
||
|
|
|
||
|
|
type ChainMiddleware func(w http.ResponseWriter, r *http.Request) bool
|
||
|
|
|
||
|
|
// Apply Middleware before Processing Request
|
||
|
|
func Chain(h http.HandlerFunc, wares ...ChainMiddleware) http.HandlerFunc {
|
||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
for _, mw := range wares {
|
||
|
|
if !mw(w, r) {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
h(w, r)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
type MethodHandler map[string]http.HandlerFunc
|
||
|
|
|
||
|
|
func (mh MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
|
|
if handler, ok := mh[r.Method]; ok {
|
||
|
|
handler(w, r)
|
||
|
|
} else {
|
||
|
|
SendClientError(w, r, ERROR_GENERIC_METHOD_NOT_ALLOWED)
|
||
|
|
}
|
||
|
|
}
|