rc-1
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEventsUnsupported = errors.New("events unsupported")
|
||||
)
|
||||
|
||||
type EventHelper struct {
|
||||
m sync.Mutex
|
||||
c context.Context
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
f http.Flusher
|
||||
}
|
||||
|
||||
func NewEventHelper(ctx context.Context, w http.ResponseWriter, r *http.Request) (*EventHelper, error) {
|
||||
|
||||
// Setup Connection
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, ErrEventsUnsupported
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
sse := &EventHelper{f: flusher, c: ctx, w: w, r: r}
|
||||
|
||||
// Heartbeat Generator
|
||||
go func(h *EventHelper) {
|
||||
t := time.NewTicker(5 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-h.c.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
h.m.Lock()
|
||||
fmt.Fprintf(h.w, ": ping\n\n")
|
||||
h.f.Flush()
|
||||
h.m.Unlock()
|
||||
}
|
||||
}
|
||||
}(sse)
|
||||
|
||||
return sse, nil
|
||||
}
|
||||
|
||||
func (h *EventHelper) SendJSON(eventName string, eventData any) {
|
||||
|
||||
b, err := json.Marshal(map[string]any{
|
||||
"name": eventName,
|
||||
"data": eventData,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.m.Lock()
|
||||
fmt.Fprintf(h.w, "data: %s\n\n", b)
|
||||
h.f.Flush()
|
||||
h.m.Unlock()
|
||||
}
|
||||
|
||||
func (h *EventHelper) SendServerError(err error) {
|
||||
SendServerError(nil, h.r, err)
|
||||
h.SendClientError(ERROR_GENERIC_SERVER)
|
||||
}
|
||||
|
||||
func (h *EventHelper) SendClientError(err APIError) {
|
||||
h.SendJSON("error", err)
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type ChallengeSessionData struct {
|
||||
Expires int64
|
||||
Difficulty int
|
||||
}
|
||||
|
||||
type RatelimitEntry struct {
|
||||
Reset int64
|
||||
Usage int32
|
||||
}
|
||||
|
||||
type RatelimitShard struct {
|
||||
sync.Mutex
|
||||
data map[string]RatelimitEntry
|
||||
}
|
||||
|
||||
const (
|
||||
RatelimitShardCount = 256
|
||||
)
|
||||
|
||||
var (
|
||||
ChallengeAtomic sync.Mutex
|
||||
ChallengeSession = make(map[string]ChallengeSessionData, 1024)
|
||||
RatelimitShards = make([]RatelimitShard, RatelimitShardCount)
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i := range RatelimitShards {
|
||||
RatelimitShards[i].data = make(map[string]RatelimitEntry, 128)
|
||||
}
|
||||
|
||||
// Cleanup Challenges
|
||||
go func() {
|
||||
t := time.NewTicker(10 * time.Minute)
|
||||
defer t.Stop()
|
||||
for range t.C {
|
||||
now := time.Now().UnixNano()
|
||||
ChallengeAtomic.Lock()
|
||||
for k, v := range ChallengeSession {
|
||||
if now > v.Expires {
|
||||
delete(ChallengeSession, k)
|
||||
}
|
||||
}
|
||||
ChallengeAtomic.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Cleanup Ratelimits
|
||||
go func() {
|
||||
t := time.NewTicker(10 * time.Minute)
|
||||
defer t.Stop()
|
||||
for range t.C {
|
||||
now := time.Now().UnixNano()
|
||||
for i := range RatelimitShards {
|
||||
s := &RatelimitShards[i]
|
||||
s.Lock()
|
||||
for k, v := range s.data {
|
||||
if now >= v.Reset {
|
||||
delete(s.data, k)
|
||||
}
|
||||
}
|
||||
s.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func NewChallenger(minimumDifficulty int) ChainMiddleware {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
|
||||
// --- Proof of Work ---
|
||||
var givenNonce = r.Header.Get("X-Challenge-Nonce")
|
||||
var givenCounterRaw = r.Header.Get("X-Challenge-Counter")
|
||||
var givenCounter = 0
|
||||
|
||||
if _, err := hex.DecodeString(givenNonce); err != nil {
|
||||
SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
|
||||
return false
|
||||
}
|
||||
if v, err := strconv.Atoi(givenCounterRaw); err != nil || v < 0 {
|
||||
SendClientError(w, r, ERROR_BODY_INVALID_CHALLENGE)
|
||||
return false
|
||||
} else {
|
||||
givenCounter = v
|
||||
}
|
||||
|
||||
// Consume Session
|
||||
ChallengeAtomic.Lock()
|
||||
session, exists := ChallengeSession[givenNonce]
|
||||
if !exists {
|
||||
ChallengeAtomic.Unlock()
|
||||
SendClientError(w, r, ERROR_UNKNOWN_CHALLENGE)
|
||||
return false
|
||||
}
|
||||
delete(ChallengeSession, givenNonce)
|
||||
ChallengeAtomic.Unlock()
|
||||
|
||||
if time.Now().Unix() >= session.Expires {
|
||||
SendClientError(w, r, ERROR_CHALLENGE_EXPIRED)
|
||||
return false
|
||||
}
|
||||
if session.Difficulty < minimumDifficulty {
|
||||
SendClientError(w, r, ERROR_CHALLENGE_TOO_EASY)
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate Results
|
||||
sessionInput := fmt.Sprintf("%s%d", givenNonce, givenCounter)
|
||||
sessionHash := sha256.Sum256([]byte(sessionInput))
|
||||
zeroBitsRequired := session.Difficulty
|
||||
zeroBitsFound := 0
|
||||
|
||||
for _, b := range sessionHash {
|
||||
if b == 0 {
|
||||
zeroBitsFound += 8
|
||||
} else {
|
||||
zeroBitsFound += bits.LeadingZeros8(b)
|
||||
break
|
||||
}
|
||||
}
|
||||
if zeroBitsFound < zeroBitsRequired {
|
||||
SendClientError(w, r, ERROR_CHALLENGE_INVALID)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent Spam by Limiting Amount of Incoming Requests
|
||||
func NewRatelimiter(categoryName string, limit int32, period time.Duration) ChainMiddleware {
|
||||
limitStr := strconv.Itoa(int(limit))
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
name := categoryName + ":" + RequestAddressHash(r)
|
||||
|
||||
// FNV-1a 32-bit
|
||||
var h uint64 = 2166136261
|
||||
for i := 0; i < len(name); i++ {
|
||||
h ^= uint64(name[i])
|
||||
h *= 16777619
|
||||
}
|
||||
|
||||
// Calculate Usage
|
||||
s := &RatelimitShards[h%RatelimitShardCount]
|
||||
s.Lock()
|
||||
now := time.Now()
|
||||
e, ok := s.data[name]
|
||||
if !ok || now.UnixNano() >= e.Reset {
|
||||
e = RatelimitEntry{Reset: now.Add(period).UnixNano(), Usage: 1}
|
||||
} else {
|
||||
e.Usage++
|
||||
}
|
||||
s.data[name] = e
|
||||
s.Unlock()
|
||||
|
||||
// Generate Headers
|
||||
resetSecs := strconv.FormatFloat(float64(e.Reset-now.UnixNano())/float64(time.Second), 'f', 2, 64)
|
||||
remaining := max(0, limit-e.Usage)
|
||||
|
||||
hdr := w.Header()
|
||||
hdr.Set("X-Ratelimit-Category", categoryName)
|
||||
hdr.Set("X-Ratelimit-Reset", resetSecs)
|
||||
hdr.Set("X-Ratelimit-Limit", limitStr)
|
||||
hdr.Set("X-Ratelimit-Remaining", strconv.Itoa(int(remaining)))
|
||||
|
||||
if e.Usage > int32(limit) {
|
||||
SendClientError(w, r, ERROR_GENERIC_RATELIMIT)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Create Gatekeeper
|
||||
func NewGatekeeper(allowModerators bool) ChainMiddleware {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
|
||||
if allowModerators {
|
||||
var Count int
|
||||
err := Database.QueryRow(r.Context(),
|
||||
"SELECT 1 FROM gifuu.mod_key WHERE token_hash = $1",
|
||||
RequestHash(r.Header.Get("Authorization")),
|
||||
).Scan(
|
||||
&Count,
|
||||
)
|
||||
if err != nil && err != pgx.ErrNoRows {
|
||||
SendServerError(w, r, err)
|
||||
return false
|
||||
}
|
||||
if err == pgx.ErrNoRows {
|
||||
SendClientError(w, r, ERROR_GENERIC_UNAUTHORIZED)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type ChainMiddleware func(w http.ResponseWriter, r *http.Request) bool
|
||||
|
||||
// Apply Middleware before Processing Request
|
||||
func Chain(h http.HandlerFunc, wares ...ChainMiddleware) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, mw := range wares {
|
||||
if !mw(w, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
type MethodHandler map[string]http.HandlerFunc
|
||||
|
||||
func (mh MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if handler, ok := mh[r.Method]; ok {
|
||||
handler(w, r)
|
||||
} else {
|
||||
SendClientError(w, r, ERROR_GENERIC_METHOD_NOT_ALLOWED)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
snowflakeMutex sync.Mutex
|
||||
snowflakeMachineID int64
|
||||
snowflakeSequence int64
|
||||
snowflakeTimestamp int64
|
||||
|
||||
// Updates to normalizers here should be mirrored in `GET_Limits.go`!
|
||||
|
||||
RegexSpaces = regexp.MustCompile(`\s{2,}`)
|
||||
RegexUnderscores = regexp.MustCompile(`_+`)
|
||||
RegexNewlines = regexp.MustCompile(`\n{2,}`)
|
||||
RegexMatcherTitle = regexp.MustCompile(`^[\S\s]{1,80}$`)
|
||||
RegexMatcherTag = regexp.MustCompile(`^[\p{L}\p{N}_]{1,32}$`)
|
||||
RegexMatcherComment = regexp.MustCompile(`^[\S\s]{10,240}$`)
|
||||
)
|
||||
|
||||
func NormalizeTitle(str string) (string, bool) {
|
||||
if str == "" {
|
||||
return str, false
|
||||
}
|
||||
if !RegexMatcherTitle.MatchString(str) {
|
||||
return str, false
|
||||
}
|
||||
str = RegexSpaces.ReplaceAllString(str, " ")
|
||||
str = strings.TrimSpace(str)
|
||||
str = html.EscapeString(str)
|
||||
return str, true
|
||||
}
|
||||
|
||||
func NormalizeTag(str string) (string, bool) {
|
||||
if str == "" {
|
||||
return str, false
|
||||
}
|
||||
if !RegexMatcherTag.MatchString(str) {
|
||||
return str, false
|
||||
}
|
||||
str = RegexUnderscores.ReplaceAllString(str, "_")
|
||||
str = strings.Trim(str, "_")
|
||||
str = strings.ToUpper(str)
|
||||
return str, true
|
||||
}
|
||||
|
||||
func NormalizeComment(str string) (string, bool) {
|
||||
if str == "" {
|
||||
return str, false
|
||||
}
|
||||
if !RegexMatcherComment.MatchString(str) {
|
||||
return str, false
|
||||
}
|
||||
str = RegexNewlines.ReplaceAllString(str, " ")
|
||||
str = strings.TrimSpace(str)
|
||||
str = html.EscapeString(str)
|
||||
return str, true
|
||||
}
|
||||
|
||||
func ParseLimit(str string) int {
|
||||
v, _ := strconv.Atoi(str)
|
||||
return min(100, max(1, v))
|
||||
}
|
||||
|
||||
func ParseSnowflake(str string) int64 {
|
||||
v, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil || v < 1 {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func ParseJSON(r io.Reader, v any) error {
|
||||
l := io.LimitReader(r, int64(LIMIT_JSON))
|
||||
d := json.NewDecoder(l)
|
||||
d.DisallowUnknownFields()
|
||||
return d.Decode(v)
|
||||
}
|
||||
|
||||
// Generate a Unique Snowflake
|
||||
func RequestSnowflake() int64 {
|
||||
snowflakeMutex.Lock()
|
||||
defer snowflakeMutex.Unlock()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
if now != snowflakeTimestamp {
|
||||
snowflakeSequence = 0
|
||||
} else {
|
||||
snowflakeSequence++
|
||||
if snowflakeSequence > SNOWFLAKE_MAX_SEQUENCE {
|
||||
for now <= snowflakeTimestamp {
|
||||
time.Sleep(time.Millisecond)
|
||||
now = time.Now().UnixMilli()
|
||||
}
|
||||
snowflakeSequence = 0
|
||||
}
|
||||
}
|
||||
|
||||
snowflakeTimestamp = now
|
||||
return ((now - SNOWFLAKE_EPOCH_MILLI) << 22) | (snowflakeMachineID << 12) | snowflakeSequence
|
||||
}
|
||||
|
||||
// Generate a Hex String out of random bytes
|
||||
func RequestToken() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
panic("failed to generate enough random bytes")
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// Generate a SHA256 Hex String from a given string
|
||||
func RequestHash(str string) string {
|
||||
h := sha256.Sum256([]byte(str))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// Get the Client IP Address as a SHA256 Hex String
|
||||
func RequestAddressHash(r *http.Request) string {
|
||||
return RequestHash(RequestAddress(r))
|
||||
}
|
||||
|
||||
// Get the Client IP Address
|
||||
func RequestAddress(r *http.Request) string {
|
||||
if HTTP_PROXY != "" {
|
||||
return r.Header.Get(HTTP_PROXY)
|
||||
}
|
||||
addr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return addr
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type APIError struct {
|
||||
Status int `json:"-"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
var (
|
||||
ERROR_GENERIC_SERVER = APIError{Status: 500, Code: 10000, Message: "Server Error"}
|
||||
ERROR_GENERIC_NOT_FOUND = APIError{Status: 404, Code: 10001, Message: "Endpoint Not Found"}
|
||||
ERROR_GENERIC_RATELIMIT = APIError{Status: 429, Code: 10002, Message: "Too Many Requests"}
|
||||
ERROR_GENERIC_UNAUTHORIZED = APIError{Status: 401, Code: 10003, Message: "Unauthorized"}
|
||||
ERROR_GENERIC_FORBIDDEN = APIError{Status: 403, Code: 10004, Message: "Forbidden"}
|
||||
ERROR_GENERIC_METHOD_NOT_ALLOWED = APIError{Status: 405, Code: 10005, Message: "Method Not Allowed"}
|
||||
ERROR_SERVER_RESOURCES_EXHAUSTED = APIError{Status: 507, Code: 11006, Message: "Resources Exhausted"}
|
||||
ERROR_BODY_EMPTY = APIError{Status: 411, Code: 12000, Message: "Request Body is Empty"}
|
||||
ERROR_BODY_TOO_LARGE = APIError{Status: 413, Code: 12001, Message: "Request Body is Too Large"}
|
||||
ERROR_BODY_INVALID_CONTENT_TYPE = APIError{Status: 400, Code: 12002, Message: "Invalid 'Content-Type' Header"}
|
||||
ERROR_BODY_INVALID_CHALLENGE = APIError{Status: 400, Code: 12003, Message: "Invalid 'X-Challenge-*' Header"}
|
||||
ERROR_BODY_INVALID_FIELD = APIError{Status: 400, Code: 12004, Message: "Invalid Body Field"}
|
||||
ERROR_BODY_INVALID_DATA = APIError{Status: 422, Code: 12005, Message: "Invalid Body"}
|
||||
ERROR_UNKNOWN_ENDPOINT = APIError{Status: 404, Code: 13000, Message: "Unknown Endpoint"}
|
||||
ERROR_UNKNOWN_FUNCTION = APIError{Status: 404, Code: 13001, Message: "Unknown Function"}
|
||||
ERROR_UNKNOWN_ANIMATION = APIError{Status: 404, Code: 13002, Message: "Unknown Animation"}
|
||||
ERROR_UNKNOWN_TASK = APIError{Status: 404, Code: 13003, Message: "Unknown Task"}
|
||||
ERROR_UNKNOWN_CHALLENGE = APIError{Status: 404, Code: 13004, Message: "Unknown Challenge"}
|
||||
ERROR_CHALLENGE_INVALID = APIError{Status: 400, Code: 14000, Message: "Invalid Challenge Result"}
|
||||
ERROR_CHALLENGE_TOO_EASY = APIError{Status: 400, Code: 14001, Message: "Challenge Difficulty Too Low"}
|
||||
ERROR_CHALLENGE_EXPIRED = APIError{Status: 400, Code: 14002, Message: "Challenge Expired"}
|
||||
ERROR_MEDIA_INVALID = APIError{Status: 400, Code: 15000, Message: "Media Invalid"}
|
||||
ERROR_MEDIA_INAPPROPRIATE = APIError{Status: 400, Code: 15001, Message: "Media Inappropriate"}
|
||||
)
|
||||
|
||||
// Reject request due to a Client Mistake
|
||||
func SendClientError(w http.ResponseWriter, r *http.Request, err APIError) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(err.Status)
|
||||
fmt.Fprintf(w, `{"code":%d,"message":%q}`, err.Code, err.Message)
|
||||
}
|
||||
|
||||
// Reject request due to a Server Error
|
||||
// Additionally collects debug information and logs it to the console
|
||||
func SendServerError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
|
||||
debugStack := strings.Split(string(debug.Stack()), "\n")
|
||||
for i, item := range debugStack {
|
||||
debugStack[i] = strings.ReplaceAll(item, "\t", " ")
|
||||
}
|
||||
if len(debugStack) > 5 {
|
||||
debugStack = debugStack[5:] // skip header
|
||||
}
|
||||
|
||||
reqHeader := make(map[string]string, len(r.Header))
|
||||
for key, header := range r.Header {
|
||||
reqHeader[key] = strings.Join(header, ", ")
|
||||
}
|
||||
|
||||
LoggerHTTP.Data(ERROR, err.Error(), map[string]any{
|
||||
"request": map[string]any{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"headers": reqHeader,
|
||||
},
|
||||
"error": map[string]any{
|
||||
"raw": err,
|
||||
"message": err.Error(),
|
||||
"stack": debugStack,
|
||||
},
|
||||
})
|
||||
|
||||
if w != nil {
|
||||
SendClientError(w, r, ERROR_GENERIC_SERVER)
|
||||
}
|
||||
}
|
||||
|
||||
// Respond to the request with a JSON object
|
||||
func SendJSON(w http.ResponseWriter, r *http.Request, statusCode int, responseObject any) (int, error) {
|
||||
|
||||
// Check Compression
|
||||
var g io.Writer
|
||||
if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
z := gzip.NewWriter(w)
|
||||
defer z.Close()
|
||||
g = z
|
||||
} else {
|
||||
g = w
|
||||
}
|
||||
|
||||
// Stream Object
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
if b, ok := responseObject.([]byte); ok {
|
||||
return g.Write(b)
|
||||
} else {
|
||||
j := json.NewEncoder(g)
|
||||
return 0, j.Encode(responseObject)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode Object as JSON and gzipped version
|
||||
func PrepareStaticJSON(responseObject any) ([]byte, []byte, error) {
|
||||
|
||||
// Encode Object
|
||||
buf, err := json.Marshal(responseObject)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Compress Object
|
||||
cmp := bytes.Buffer{}
|
||||
zip := gzip.NewWriter(&cmp)
|
||||
|
||||
if _, err := zip.Write(buf); err != nil {
|
||||
zip.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := zip.Close(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return buf, cmp.Bytes(), nil
|
||||
}
|
||||
|
||||
// Respond to the request with a Static JSON Object
|
||||
func SendStaticJSON(w http.ResponseWriter, r *http.Request, statusCode int, content []byte, gzipped []byte) (int, error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if gzipped != nil && strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
return w.Write(gzipped)
|
||||
}
|
||||
return w.Write(content)
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var Database *pgxpool.Pool
|
||||
|
||||
func SetupDatabase(stop context.Context, await *sync.WaitGroup) {
|
||||
|
||||
var err error
|
||||
t := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create and Test Client
|
||||
cfg, err := pgxpool.ParseConfig(DATABASE_URL)
|
||||
if err != nil {
|
||||
LoggerDatabase.Log(FATAL, "Invalid Database URI: %s", err)
|
||||
return
|
||||
}
|
||||
if Database, err = pgxpool.NewWithConfig(ctx, cfg); err != nil {
|
||||
LoggerDatabase.Log(FATAL, "Failed to create pool: %s", err.Error())
|
||||
return
|
||||
}
|
||||
if err = Database.Ping(ctx); err != nil {
|
||||
LoggerDatabase.Log(FATAL, "Failed to ping database: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Shutdown Logic
|
||||
await.Add(1)
|
||||
go func() {
|
||||
defer await.Done()
|
||||
<-stop.Done()
|
||||
Database.Close()
|
||||
LoggerDatabase.Log(INFO, "Closed")
|
||||
}()
|
||||
|
||||
LoggerDatabase.Log(INFO, "Ready in %s", time.Since(t))
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LoggerSeverity string
|
||||
|
||||
const (
|
||||
INFO LoggerSeverity = "INFO" // This is a basic informational alert.
|
||||
WARN LoggerSeverity = "WARN" // This is a warning, meaning the program has recovered from an error.
|
||||
DEBUG LoggerSeverity = "DEBUG" // This is a detailed alert containing information used for debugging.
|
||||
ERROR LoggerSeverity = "ERROR" // An error has occurred, please advise.
|
||||
FATAL LoggerSeverity = "FATAL" // An irrecoverable error has occured and the program must exit immediately.
|
||||
)
|
||||
|
||||
var (
|
||||
LoggerInit = &LoggerInstance{source: "INIT"}
|
||||
LoggerHTTP = &LoggerInstance{source: "HTTP"}
|
||||
LoggerModel = &LoggerInstance{source: "ONNX"}
|
||||
LoggerStorage = &LoggerInstance{source: "DISK"}
|
||||
LoggerDatabase = &LoggerInstance{source: "RMDB"}
|
||||
)
|
||||
|
||||
type LoggerInstance struct {
|
||||
source string
|
||||
}
|
||||
|
||||
func (p *LoggerInstance) entry(severity LoggerSeverity, source, message string) {
|
||||
target := os.Stdout
|
||||
if severity == ERROR || severity == FATAL {
|
||||
target = os.Stderr
|
||||
}
|
||||
fmt.Fprintf(target, "%s [%s] [%s] %s\n", time.Now().Format(time.DateTime), severity, source, message)
|
||||
}
|
||||
|
||||
func (p *LoggerInstance) Log(severity LoggerSeverity, format string, a ...any) {
|
||||
p.entry(severity, p.source, fmt.Sprintf(format, a...))
|
||||
if severity == FATAL {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *LoggerInstance) Data(severity LoggerSeverity, message string, data any) {
|
||||
if data == nil {
|
||||
p.entry(severity, p.source, message)
|
||||
} else {
|
||||
entryData := ""
|
||||
if b, err := json.MarshalIndent(data, "", " "); err != nil {
|
||||
entryData = fmt.Sprintf("marshal_error: %q", err)
|
||||
} else {
|
||||
entryData = string(b)
|
||||
}
|
||||
p.entry(severity, p.source, fmt.Sprintf("%s\n%s\n---", message, entryData))
|
||||
}
|
||||
if severity == FATAL {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gifuu/include"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
onnx "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_THRESHOLD_DENY = 0.95
|
||||
MODEL_THRESHOLD_HIDE = 0.75
|
||||
MODEL_SIZE = 224
|
||||
MODEL_FRAMERATE = 3
|
||||
)
|
||||
|
||||
var onnxSession *onnx.DynamicAdvancedSession
|
||||
|
||||
type ClassifyResults struct {
|
||||
Drawing float32
|
||||
Hentai float32
|
||||
Neutral float32
|
||||
Porn float32
|
||||
Sexy float32
|
||||
}
|
||||
|
||||
func SetupModel(stop context.Context, await *sync.WaitGroup) {
|
||||
if ONNX_RUNTIME_PATH == "" {
|
||||
LoggerModel.Log(WARN, "Set runtime path with envvar ONNX_RUNTIME_PATH to enable model")
|
||||
return
|
||||
}
|
||||
t := time.Now()
|
||||
|
||||
// Initialize Environment
|
||||
onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH)
|
||||
if err := onnx.InitializeEnvironment(); err != nil {
|
||||
LoggerModel.Log(FATAL, "Failed to initialize ONNX Runtime: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize Settings
|
||||
options, err := onnx.NewSessionOptions()
|
||||
if err != nil {
|
||||
LoggerModel.Log(FATAL, "Failed to create session options: %s", err)
|
||||
return
|
||||
}
|
||||
defer options.Destroy()
|
||||
|
||||
if ONNX_RUNTIME_CUDA {
|
||||
cudaOptions, err := onnx.NewCUDAProviderOptions()
|
||||
if err != nil {
|
||||
LoggerModel.Log(WARN, "CUDA unavailable, falling back to CPU: %s", err)
|
||||
} else {
|
||||
defer cudaOptions.Destroy()
|
||||
cudaOptions.Update(map[string]string{
|
||||
"cudnn_conv_algo_search": "DEFAULT", // use the only working frontend
|
||||
})
|
||||
options.AppendExecutionProviderCUDA(cudaOptions)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize Model
|
||||
session, err := onnx.NewDynamicAdvancedSessionWithONNXData(
|
||||
include.MODEL_NSFW,
|
||||
[]string{"input"},
|
||||
[]string{"prediction"},
|
||||
options,
|
||||
)
|
||||
if err != nil {
|
||||
LoggerModel.Log(FATAL, "Failed to load model: %s", err)
|
||||
return
|
||||
}
|
||||
onnxSession = session
|
||||
|
||||
// Test Model with Dummy Data
|
||||
dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3)
|
||||
if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil {
|
||||
LoggerModel.Log(FATAL, "Failed to initialize model: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
await.Add(1)
|
||||
go func() {
|
||||
defer await.Done()
|
||||
<-stop.Done()
|
||||
onnxSession.Destroy()
|
||||
onnxSession = nil
|
||||
onnx.DestroyEnvironment()
|
||||
LoggerModel.Log(INFO, "Closed")
|
||||
}()
|
||||
|
||||
LoggerModel.Log(INFO, "Model ready in %s", time.Since(t))
|
||||
}
|
||||
|
||||
func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) {
|
||||
|
||||
// Model is disabled, generate some dummy results.
|
||||
if onnxSession == nil {
|
||||
results := make([]ClassifyResults, count)
|
||||
for i := 0; i < count; i++ {
|
||||
results = append(results, ClassifyResults{Neutral: 1})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
inputTensor, err := onnx.NewTensor(
|
||||
onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3),
|
||||
data,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer inputTensor.Destroy()
|
||||
|
||||
outputs := []onnx.ArbitraryTensor{nil}
|
||||
if err := onnxSession.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer outputs[0].Destroy()
|
||||
|
||||
raw := outputs[0].(*onnx.Tensor[float32]).GetData()
|
||||
results := make([]ClassifyResults, count)
|
||||
|
||||
for i := range results {
|
||||
base := i * 5
|
||||
results[i] = ClassifyResults{
|
||||
Drawing: raw[base+0],
|
||||
Hentai: raw[base+1],
|
||||
Neutral: raw[base+2],
|
||||
Porn: raw[base+3],
|
||||
Sexy: raw[base+4],
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
const (
|
||||
SNOWFLAKE_MAX_MACHINE_ID int64 = (1 << 10) - 1
|
||||
SNOWFLAKE_MAX_SEQUENCE int64 = (1 << 12) - 1
|
||||
SNOWFLAKE_EPOCH_MILLI = 1207008000000
|
||||
SNOWFLAKE_EPOCH_SECONDS = SNOWFLAKE_EPOCH_MILLI / 1000 // Apr 1st 2008 (Teto B-Day!)
|
||||
TIMEOUT_SHUTDOWN = 1 * time.Minute // Standard Timeout for Shutdowns
|
||||
TIMEOUT_CONTEXT = 10 * time.Second // Standard Timeout for Requests
|
||||
FILE_PUBLIC = os.FileMode(0770) // rwxrwx---
|
||||
FILE_PRIVATE = os.FileMode(0700) // rwx------
|
||||
)
|
||||
|
||||
var (
|
||||
TEMP_CAPACITY atomic.Int64
|
||||
LIMIT_JSON = EnvNumber("LIMIT_JSON", 8*1024) // ( 8KB) Size limit per incoming JSON string
|
||||
LIMIT_FILE = EnvNumber("LIMIT_FILE", 25*1024*1024) // (25MB) Size limit per incoming media file
|
||||
LIMIT_TEMP = EnvNumber("LIMIT_TEMP", 2*1024*1024*1024) // ( 2GB) Disk space allowed for temporary files
|
||||
LIMIT_ENCODES = EnvNumber("LIMIT_ENCODES", 1) // Concurrent Uploads
|
||||
LIMIT_PROBES = EnvNumber("LIMIT_PROBES", 4) // Concurrent Probes
|
||||
LIMIT_MIME_TYPE = EnvSlice("LIMIT_MIME_TYPE", ",", []string{
|
||||
/* STANDARD */ "image/jpeg", "image/png", "image/gif", "image/webp", "image/heic", "image/heif",
|
||||
/* FUTURE */ "image/avif", "image/jxl",
|
||||
/* LEGACY */ "image/tiff", "image/bmp",
|
||||
/* STANDARD */ "video/mp4", "video/webm", "video/quicktime", "video/x-matroska",
|
||||
/* LEGACY */ "video/avi", "video/x-ms-wmv",
|
||||
})
|
||||
TEMPLATE_BASE_WEB = EnvString("TEMPLATE_BASE_WEB", "http://localhost:5173")
|
||||
TEMPLATE_BASE_CDN = EnvString("TEMPLATE_BASE_CDN", "http://localhost:3000")
|
||||
TEMPLATE_BASE_API = EnvString("TEMPLATE_BASE_API", "http://localhost:8080")
|
||||
MACHINE_ID = EnvString("MACHINE_ID", "0")
|
||||
MACHINE_HOSTNAME = EnvString("MACHINE_HOSTNAME", "le fishe")
|
||||
MACHINE_PROVERB = EnvString("MACHINE_PROVERB", "><> .o( blub blub)")
|
||||
DATABASE_URL = EnvString("DATABASE_URL", "postgresql://postgres:password@localhost:5432")
|
||||
STORAGE_DISK_TEMP = EnvString("STORAGE_DISK_TEMP", "_temp")
|
||||
STORAGE_DISK_PUBLIC = EnvString("STORAGE_DISK_PUBLIC", "_public")
|
||||
ONNX_RUNTIME_PATH = EnvString("ONNX_RUNTIME_PATH", "")
|
||||
ONNX_RUNTIME_CUDA = EnvString("ONNX_RUNTIME_CUDA", "") != ""
|
||||
HTTP_ADDRESS = EnvString("HTTP_ADDRESS", "127.0.0.1:8080")
|
||||
HTTP_PROXY = EnvString("HTTP_PROXY", "")
|
||||
)
|
||||
|
||||
var (
|
||||
SEMA_UPLOADS = semaphore.NewWeighted(int64(LIMIT_TEMP))
|
||||
SEMA_ENCODES = semaphore.NewWeighted(int64(LIMIT_ENCODES))
|
||||
SEMA_PROBES = semaphore.NewWeighted(int64(LIMIT_PROBES))
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Prepare Directories
|
||||
if err := os.MkdirAll(STORAGE_DISK_PUBLIC, FILE_PUBLIC); err != nil {
|
||||
LoggerInit.Log(FATAL, "Cannot Create Public Directory")
|
||||
return
|
||||
}
|
||||
if err := os.MkdirAll(STORAGE_DISK_TEMP, FILE_PUBLIC); err != nil {
|
||||
LoggerInit.Log(FATAL, "Cannot Create Temp Directory")
|
||||
return
|
||||
}
|
||||
|
||||
// Check Executables
|
||||
if err := exec.Command("ffmpeg", "--help").Run(); err != nil {
|
||||
LoggerInit.Log(FATAL, "FFmpeg failed to start: %s", err)
|
||||
return
|
||||
}
|
||||
if err := exec.Command("ffprobe", "--help").Run(); err != nil {
|
||||
LoggerInit.Log(FATAL, "FFprobe failed to start: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Read String from Environment
|
||||
func EnvString(field, initial string) string {
|
||||
if value := os.Getenv(field); value == "" {
|
||||
return initial
|
||||
} else {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// Read String from Environment and Parse it as a number
|
||||
func EnvNumber(field string, initial int) int {
|
||||
if value := os.Getenv(field); value == "" {
|
||||
return initial
|
||||
} else if number, err := strconv.Atoi(value); err != nil {
|
||||
return initial
|
||||
} else {
|
||||
return number
|
||||
}
|
||||
}
|
||||
|
||||
// Read String from Environment and Parse it as a slice using the given delimiter
|
||||
func EnvSlice(field, delimiter string, initial []string) []string {
|
||||
if value := os.Getenv(field); value == "" {
|
||||
return initial
|
||||
} else {
|
||||
return strings.Split(value, delimiter)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StringInteger float64
|
||||
|
||||
func (d *StringInteger) UnmarshalJSON(data []byte) error {
|
||||
s := strings.Trim(string(data), `"`)
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = StringInteger(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
type StringFloat float64
|
||||
|
||||
func (d *StringFloat) UnmarshalJSON(data []byte) error {
|
||||
s := strings.Trim(string(data), `"`)
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = StringFloat(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
type StringFramerate float64
|
||||
|
||||
func (d *StringFramerate) UnmarshalJSON(data []byte) error {
|
||||
s := strings.Trim(string(data), `"`)
|
||||
|
||||
if parts := strings.SplitN(s, "/", 2); len(parts) == 2 {
|
||||
|
||||
num, err := strconv.ParseFloat(parts[0], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numerator: %s", err)
|
||||
}
|
||||
den, err := strconv.ParseFloat(parts[1], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid denominator: %s", err)
|
||||
}
|
||||
|
||||
if den == 0 {
|
||||
den = 1
|
||||
}
|
||||
|
||||
*d = StringFramerate(num / den)
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse framerate")
|
||||
}
|
||||
*d = StringFramerate(f)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProbeStream struct {
|
||||
Index int `json:"index"` // 0
|
||||
CodecType string `json:"codec_type"` // video
|
||||
Width int `json:"width"` // 1920
|
||||
Height int `json:"height"` // 1080
|
||||
NumberFrames StringInteger `json:"nb_frames"` // 1
|
||||
RFrameRate StringFramerate `json:"r_frame_rate"` // 15/1
|
||||
Duration StringFloat `json:"duration"` // 251.800
|
||||
}
|
||||
|
||||
type ProbeResults struct {
|
||||
Streams []ProbeStream `json:"streams"`
|
||||
}
|
||||
Reference in New Issue
Block a user