Files

441 lines
11 KiB
Go
Raw Permalink Normal View History

2026-05-23 17:22:03 -07:00
package main
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha256"
_ "embed"
"encoding/json"
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"io"
"log"
"math"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"time"
onnx "github.com/yalue/onnxruntime_go"
"golang.org/x/image/draw"
"golang.org/x/image/webp"
"golang.org/x/sync/semaphore"
)
// ImageType names an image container format recognised by sniffing the
// leading bytes of a request body.
type ImageType string

// Formats the service can decode, plus IMAGE_OTHER for anything unrecognised.
const (
	IMAGE_JPEG  ImageType = "JPG"
	IMAGE_PNG   ImageType = "PNG"
	IMAGE_GIF   ImageType = "GIF"
	IMAGE_WEBP  ImageType = "WEBP"
	IMAGE_OTHER ImageType = "UNKNOWN"
)

// MODEL_SIZE is the width and height, in pixels, that every image is scaled
// to before inference (untyped so it can be used as int or int64).
const MODEL_SIZE = 224
var (
//go:embed nsfw.onnx
ONNX_MODEL []byte
ONNX_SESSION *onnx.DynamicAdvancedSession
ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH")
ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != ""
HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS")
HTTP_SEMAPHORE = semaphore.NewWeighted(16)
HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb
MODEL_THRESHOLD float32 = 0.7
)
// init applies environment overrides to the package configuration:
//   - HTTP_ADDRESS falls back to 127.0.0.1:9000 when unset.
//   - HTTP_CONCURRENCY (positive integer) resizes HTTP_SEMAPHORE.
//   - HTTP_MAX_BODY_BYTES (positive integer) caps the request body size.
//   - MODEL_THRESHOLD (finite number) sets the Rating cut-off.
//
// Malformed or out-of-range values are logged and the default is kept.
// Non-positive sizes are rejected because a zero-weight semaphore can never
// be acquired and a zero body limit rejects every image; NaN/Inf thresholds
// are rejected because they make the allow/deny comparison meaningless.
func init() {
	if HTTP_ADDRESS == "" {
		HTTP_ADDRESS = "127.0.0.1:9000"
	}
	if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" {
		if val, err := strconv.ParseInt(plaintext, 10, 64); err == nil && val > 0 {
			HTTP_SEMAPHORE = semaphore.NewWeighted(val)
		} else {
			log.Printf("[MAIN] Ignoring HTTP_CONCURRENCY=%q: want a positive integer\n", plaintext)
		}
	}
	if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" {
		if val, err := strconv.ParseInt(plaintext, 10, 64); err == nil && val > 0 {
			HTTP_MAX_BODY_BYTES = val
		} else {
			log.Printf("[MAIN] Ignoring HTTP_MAX_BODY_BYTES=%q: want a positive integer\n", plaintext)
		}
	}
	if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" {
		// ParseFloat accepts "NaN" and "Inf" without error, so check explicitly.
		if val, err := strconv.ParseFloat(plaintext, 32); err == nil && !math.IsNaN(val) && !math.IsInf(val, 0) {
			MODEL_THRESHOLD = float32(val)
		} else {
			log.Printf("[MAIN] Ignoring MODEL_THRESHOLD=%q: want a finite number\n", plaintext)
		}
	}
}
// ClassifyResults holds the model scores for a single image, as filled in by
// ModelClassifyTensorBatch. The HTTP handler compares Rating against
// MODEL_THRESHOLD to decide whether the image is allowed.
type ClassifyResults struct {
	// Rating is the derived score Hentai + Porn + 0.9*Sexy.
	Rating float32

	// Raw per-class model output values, in the model's output order.
	Drawing, Hentai, Neutral, Porn, Sexy float32
}
// main starts the model and the HTTP server, waits for SIGINT/SIGTERM, and
// then shuts down in dependency order: the HTTP server is drained first and
// only afterwards is the ONNX session destroyed, because in-flight requests
// call into it. If shutdown takes longer than one minute the process exits
// via log.Fatalln.
func main() {
	time.Local = time.UTC

	// Startup Services. The server and the model get independent stop
	// signals so they can be torn down one after the other.
	httpCtx, stopHTTP := context.WithCancel(context.Background())
	defer stopHTTP()
	modelCtx, stopModel := context.WithCancel(context.Background())
	defer stopModel()
	var httpWg, modelWg sync.WaitGroup

	SetupModel(modelCtx, &modelWg)

	// Keep httpWg above zero until SetupHTTP returns. SetupHTTP calls Add on
	// it from inside this goroutine; an Add racing a Wait on a zero counter
	// is WaitGroup misuse and could let shutdown skip draining the server.
	httpWg.Add(1)
	go func() {
		defer httpWg.Done()
		SetupHTTP(httpCtx, &httpWg)
	}()

	// Await Shutdown Signal (os.Interrupt already covers SIGINT).
	cancel := make(chan os.Signal, 1)
	signal.Notify(cancel, os.Interrupt, syscall.SIGTERM)
	<-cancel
	// Restore default signal handling so a second signal terminates at once.
	signal.Stop(cancel)

	// Begin Shutdown Process
	timeout, finish := context.WithTimeout(context.Background(), time.Minute)
	defer finish()
	go func() {
		<-timeout.Done()
		if timeout.Err() == context.DeadlineExceeded {
			log.Fatalln("[MAIN] Shutdown Deadline Exceeded")
		}
	}()
	stopHTTP()
	httpWg.Wait() // no request can reach ONNX_SESSION after this returns
	stopModel()
	modelWg.Wait()
	os.Exit(0)
}
func SetupHTTP(stop context.Context, await *sync.WaitGroup) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
var elapsedProbe, elapsedDecode, elapsedClassify int64
t := time.Now()
switch r.Method {
case http.MethodPost:
// Request Validation
if r.ContentLength > HTTP_MAX_BODY_BYTES {
http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge)
return
}
r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES)
var d []byte
if raw, err := io.ReadAll(r.Body); err != nil {
http.Error(w, "Body Error", http.StatusLengthRequired)
return
} else {
d = raw
}
// Image Validation
var imageType ImageType
var imageInfo image.Config
var imageData image.Image
var imageErr error
switch { // Sniffing
case len(d) > 3 && // JPEG
d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF:
imageType = IMAGE_JPEG
case len(d) > 8 && // PNG
d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 &&
d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A:
imageType = IMAGE_PNG
case len(d) > 4 && // GIF
d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38:
imageType = IMAGE_GIF
case len(d) > 12 && // WEBP
d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 &&
d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50:
imageType = IMAGE_WEBP
default:
imageType = IMAGE_OTHER
}
switch imageType { // Decode Header
case IMAGE_WEBP:
imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d))
case IMAGE_JPEG:
imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d))
case IMAGE_PNG:
imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d))
case IMAGE_GIF:
imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d))
default:
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
return
}
elapsedProbe = time.Since(t).Microseconds()
t = time.Now()
if imageErr != nil {
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
return
}
if imageInfo.Width > 8192 {
http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity)
return
}
if imageInfo.Height > 4096 {
http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity)
return
}
if imageInfo.Height < 32 || imageInfo.Width < 32 {
http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity)
return
}
switch imageType { // Decode Pixels
case IMAGE_WEBP:
imageData, imageErr = webp.Decode(bytes.NewReader(d))
case IMAGE_JPEG:
imageData, imageErr = jpeg.Decode(bytes.NewReader(d))
case IMAGE_PNG:
imageData, imageErr = png.Decode(bytes.NewReader(d))
case IMAGE_GIF:
imageData, imageErr = gif.Decode(bytes.NewReader(d))
default:
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
return
}
if imageErr != nil {
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
return
}
elapsedDecode = time.Since(t).Microseconds()
t = time.Now()
// Image Classification
results, err := ModelClassifyImage(imageData)
if err != nil {
http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError)
return
}
elapsedClassify = time.Since(t).Microseconds()
t = time.Now()
// Output Results
output, err := json.Marshal(map[string]any{
"allowed": results.Rating < MODEL_THRESHOLD,
"timings": map[string]any{
"probe": elapsedProbe,
"decode": elapsedDecode,
"classify": elapsedClassify,
"total": elapsedProbe + elapsedDecode + elapsedClassify,
},
"image": map[string]any{
"hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)),
"hash_md5": fmt.Sprintf("%x", md5.Sum(d)),
"height": imageInfo.Height,
"width": imageInfo.Width,
},
"logits": map[string]any{
"rating": math.Round(float64(results.Rating)*10000) / 100,
"drawing": math.Round(float64(results.Drawing)*10000) / 100,
"hentai": math.Round(float64(results.Hentai)*10000) / 100,
"neutral": math.Round(float64(results.Neutral)*10000) / 100,
"porn": math.Round(float64(results.Porn)*10000) / 100,
"sexy": math.Round(float64(results.Sexy)*10000) / 100,
},
})
if err != nil {
http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(output)
fmt.Fprintf(os.Stdout, "%s\n", output)
// Other Methods
case http.MethodHead, http.MethodGet:
http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK)
default:
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
}
})
svr := http.Server{
Handler: mux,
Addr: HTTP_ADDRESS,
MaxHeaderBytes: 4096,
IdleTimeout: 10 * time.Second,
ReadHeaderTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
ReadTimeout: 30 * time.Second,
}
// Shutdown Logic
await.Add(1)
go func() {
defer await.Done()
<-stop.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := svr.Shutdown(shutdownCtx); err != nil {
log.Println("[HTTP] Shutdown Failed:", err)
}
log.Println("[HTTP] Server Closed")
}()
// Server Startup
log.Println("[HTTP] Listening @", HTTP_ADDRESS)
if err := svr.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalln("[HTTP] Startup Failed:", err)
}
}
// SetupModel initializes ONNX Runtime, loads the embedded model into
// ONNX_SESSION and runs one warm-up inference on an all-zero input. Every
// failure on that path is fatal: the HTTP handler calls into ONNX_SESSION
// unconditionally, so continuing without a working model would only turn
// each request into a panic. CUDA problems are logged and skipped instead.
//
// On success it registers a goroutine on await that destroys the session
// and the ONNX environment once stop is cancelled. That goroutine does not
// wait for in-flight inference, so stop should only be cancelled after
// callers of ModelClassifyTensorBatch have finished.
func SetupModel(stop context.Context, await *sync.WaitGroup) {
	t := time.Now()
	// Initialize Environment
	onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH)
	if err := onnx.InitializeEnvironment(); err != nil {
		log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err)
	}
	// Initialize Settings
	options, err := onnx.NewSessionOptions()
	if err != nil {
		log.Fatalf("[ONNX] Failed to create session options: %s\n", err)
	}
	defer options.Destroy()
	if ONNX_RUNTIME_CUDA {
		cudaOptions, err := onnx.NewCUDAProviderOptions()
		if err != nil {
			log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err)
		} else {
			// Must outlive session creation below, hence deferred to return.
			defer cudaOptions.Destroy()
			if err := cudaOptions.Update(map[string]string{
				"cudnn_conv_algo_search": "DEFAULT", // use the only working frontend
			}); err != nil {
				log.Printf("[ONNX] Failed to apply CUDA options: %s\n", err)
			}
			if err := options.AppendExecutionProviderCUDA(cudaOptions); err != nil {
				log.Printf("[ONNX] Failed to enable CUDA, falling back to CPU: %s\n", err)
			}
		}
	}
	// Initialize Model
	session, err := onnx.NewDynamicAdvancedSessionWithONNXData(
		ONNX_MODEL,
		[]string{"input"},
		[]string{"prediction"},
		options,
	)
	if err != nil {
		log.Fatalf("[ONNX] Failed to load model: %s\n", err)
	}
	ONNX_SESSION = session
	// Test Model with Dummy Data
	dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3)
	if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil {
		log.Fatalf("[ONNX] Failed to initialize model: %s\n", err)
	}
	await.Add(1)
	go func() {
		defer await.Done()
		<-stop.Done()
		if err := ONNX_SESSION.Destroy(); err != nil {
			log.Println("[ONNX] Session Destroy Failed:", err)
		}
		ONNX_SESSION = nil
		if err := onnx.DestroyEnvironment(); err != nil {
			log.Println("[ONNX] Environment Destroy Failed:", err)
		}
		log.Println("[ONNX] Closed")
	}()
	log.Printf("[ONNX] Model ready in %s\n", time.Since(t))
}
// ModelClassifyTensorBatch runs the NSFW model on count images packed into
// data and returns one ClassifyResults per image, in input order.
//
// data is passed to onnx.NewTensor with shape
// [count, MODEL_SIZE, MODEL_SIZE, 3]. The model output is read as 5 float32
// values per image, mapped in order to Drawing, Hentai, Neutral, Porn and
// Sexy; Rating is derived as Hentai + Porn + 0.9*Sexy.
//
// It returns an error, rather than panicking, when count is not positive,
// ONNX_SESSION has not been loaded, or the model output is not a float32
// tensor with at least count*5 values. Errors from tensor creation and
// Session.Run are returned unchanged.
func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) {
	const numClasses = 5
	if count < 1 {
		return nil, fmt.Errorf("batch count must be positive, got %d", count)
	}
	// Read the global once; note this does not make use during shutdown safe.
	session := ONNX_SESSION
	if session == nil {
		return nil, fmt.Errorf("model session is not initialized")
	}
	inputTensor, err := onnx.NewTensor(
		onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3),
		data,
	)
	if err != nil {
		return nil, err
	}
	defer inputTensor.Destroy()
	// A nil output slot lets the session allocate the output tensor.
	outputs := []onnx.ArbitraryTensor{nil}
	if err := session.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil {
		return nil, err
	}
	defer outputs[0].Destroy()
	output, ok := outputs[0].(*onnx.Tensor[float32])
	if !ok {
		return nil, fmt.Errorf("unexpected model output type %T", outputs[0])
	}
	raw := output.GetData()
	if len(raw) < count*numClasses {
		return nil, fmt.Errorf("model output has %d values, want at least %d", len(raw), count*numClasses)
	}
	results := make([]ClassifyResults, count)
	for i := range results {
		scores := raw[i*numClasses : (i+1)*numClasses]
		results[i] = ClassifyResults{
			Rating:  (scores[1] + scores[3] + (scores[4] * 0.9)),
			Drawing: scores[0],
			Hentai:  scores[1],
			Neutral: scores[2],
			Porn:    scores[3],
			Sexy:    scores[4],
		}
	}
	return results, nil
}
// ModelClassifyImage returns the model's scores for someImage. It does not
// make the allow/deny decision itself; callers compare the returned Rating
// against a threshold. A nil or zero-sized image yields an error.
func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) {
	if someImage == nil || someImage.Bounds().Empty() {
		return ClassifyResults{}, fmt.Errorf("cannot classify an empty image")
	}
	results, err := ModelClassifyTensorBatch(imageToModelTensor(someImage), 1)
	if err != nil {
		return ClassifyResults{}, err
	}
	// Single-image batch: exactly one result.
	return results[0], nil
}

// imageToModelTensor scales img to MODEL_SIZE x MODEL_SIZE with Catmull-Rom
// and returns its pixels row by row as interleaved R, G, B float32 values in
// [0, 1]. The image is composited onto a transparent-black canvas and the
// alpha channel is dropped, so transparent regions come out dark.
func imageToModelTensor(img image.Image) []float32 {
	resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE))
	draw.CatmullRom.Scale(resized, resized.Rect, img, img.Bounds(), draw.Over, nil)
	// Read Pix directly instead of At(x, y).RGBA(), which boxes a color per
	// pixel. Results are bit-identical: RGBA() widens each 8-bit channel v to
	// v*257, and (v*257)/65535 == v/255 exactly.
	tensor := make([]float32, 0, MODEL_SIZE*MODEL_SIZE*3)
	for y := 0; y < MODEL_SIZE; y++ {
		start := y * resized.Stride
		row := resized.Pix[start : start+MODEL_SIZE*4]
		for px := 0; px < len(row); px += 4 {
			tensor = append(tensor, float32(row[px])/255, float32(row[px+1])/255, float32(row[px+2])/255)
		}
	}
	return tensor
}