441 lines
11 KiB
Go
441 lines
11 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/sha256"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"image"
	"image/gif"
	"image/jpeg"
	"image/png"
	"io"
	"log"
	"math"
	"net/http"
	"os"
	"os/signal"
	"strconv"
	"sync"
	"syscall"
	"time"

	onnx "github.com/yalue/onnxruntime_go"

	"golang.org/x/image/draw"
	"golang.org/x/image/webp"
	"golang.org/x/sync/semaphore"
)
|
|
|
|
// ImageType names the container format detected from an upload's leading bytes.
type ImageType string

// Formats the request handler can tell apart; IMAGE_OTHER covers everything else.
const (
	IMAGE_OTHER = ImageType("UNKNOWN")
	IMAGE_WEBP  = ImageType("WEBP")
	IMAGE_JPEG  = ImageType("JPG")
	IMAGE_PNG   = ImageType("PNG")
	IMAGE_GIF   = ImageType("GIF")
)

// MODEL_SIZE is the edge length, in pixels, of the square image fed to the model.
const MODEL_SIZE = 224
|
|
|
|
var (
|
|
//go:embed nsfw.onnx
|
|
ONNX_MODEL []byte
|
|
ONNX_SESSION *onnx.DynamicAdvancedSession
|
|
ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH")
|
|
ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != ""
|
|
HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS")
|
|
HTTP_SEMAPHORE = semaphore.NewWeighted(16)
|
|
HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb
|
|
MODEL_THRESHOLD float32 = 0.7
|
|
)
|
|
|
|
func init() {
|
|
if HTTP_ADDRESS == "" {
|
|
HTTP_ADDRESS = "127.0.0.1:9000"
|
|
}
|
|
|
|
if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" {
|
|
val, err := strconv.ParseInt(plaintext, 10, 64)
|
|
if err == nil {
|
|
HTTP_SEMAPHORE = semaphore.NewWeighted(val)
|
|
}
|
|
}
|
|
|
|
if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" {
|
|
val, err := strconv.ParseInt(plaintext, 10, 64)
|
|
if err == nil {
|
|
HTTP_MAX_BODY_BYTES = val
|
|
}
|
|
}
|
|
|
|
if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" {
|
|
val, err := strconv.ParseFloat(plaintext, 32)
|
|
if err == nil {
|
|
MODEL_THRESHOLD = float32(val)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ClassifyResults carries the scores produced for a single classified image.
type ClassifyResults struct {
	// Rating is the combined score the allow/deny decision is based on.
	Rating float32

	// Individual per-class scores reported by the model.
	Drawing, Hentai, Neutral, Porn, Sexy float32
}
|
|
|
|
func main() {
|
|
time.Local = time.UTC
|
|
|
|
// Startup Services
|
|
var stopCtx, stop = context.WithCancel(context.Background())
|
|
var stopWg sync.WaitGroup
|
|
|
|
SetupModel(stopCtx, &stopWg)
|
|
go SetupHTTP(stopCtx, &stopWg)
|
|
|
|
// Await Shutdown Signal
|
|
cancel := make(chan os.Signal, 1)
|
|
signal.Notify(cancel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
|
<-cancel
|
|
stop()
|
|
|
|
// Begin Shutdown Process
|
|
timeout, finish := context.WithTimeout(context.Background(), time.Minute)
|
|
defer finish()
|
|
go func() {
|
|
<-timeout.Done()
|
|
if timeout.Err() == context.DeadlineExceeded {
|
|
log.Fatalln("[MAIN] Shutdown Deadline Exceeded")
|
|
}
|
|
}()
|
|
stopWg.Wait()
|
|
os.Exit(0)
|
|
}
|
|
|
|
func SetupHTTP(stop context.Context, await *sync.WaitGroup) {
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
var elapsedProbe, elapsedDecode, elapsedClassify int64
|
|
t := time.Now()
|
|
|
|
switch r.Method {
|
|
|
|
case http.MethodPost:
|
|
// Request Validation
|
|
if r.ContentLength > HTTP_MAX_BODY_BYTES {
|
|
http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES)
|
|
|
|
var d []byte
|
|
if raw, err := io.ReadAll(r.Body); err != nil {
|
|
http.Error(w, "Body Error", http.StatusLengthRequired)
|
|
return
|
|
} else {
|
|
d = raw
|
|
}
|
|
|
|
// Image Validation
|
|
var imageType ImageType
|
|
var imageInfo image.Config
|
|
var imageData image.Image
|
|
var imageErr error
|
|
|
|
switch { // Sniffing
|
|
case len(d) > 3 && // JPEG
|
|
d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF:
|
|
imageType = IMAGE_JPEG
|
|
|
|
case len(d) > 8 && // PNG
|
|
d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 &&
|
|
d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A:
|
|
imageType = IMAGE_PNG
|
|
|
|
case len(d) > 4 && // GIF
|
|
d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38:
|
|
imageType = IMAGE_GIF
|
|
|
|
case len(d) > 12 && // WEBP
|
|
d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 &&
|
|
d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50:
|
|
imageType = IMAGE_WEBP
|
|
|
|
default:
|
|
imageType = IMAGE_OTHER
|
|
}
|
|
|
|
switch imageType { // Decode Header
|
|
case IMAGE_WEBP:
|
|
imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d))
|
|
case IMAGE_JPEG:
|
|
imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d))
|
|
case IMAGE_PNG:
|
|
imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d))
|
|
case IMAGE_GIF:
|
|
imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d))
|
|
default:
|
|
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
|
return
|
|
}
|
|
elapsedProbe = time.Since(t).Microseconds()
|
|
t = time.Now()
|
|
|
|
if imageErr != nil {
|
|
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if imageInfo.Width > 8192 {
|
|
http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
if imageInfo.Height > 4096 {
|
|
http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
if imageInfo.Height < 32 || imageInfo.Width < 32 {
|
|
http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
|
|
switch imageType { // Decode Pixels
|
|
case IMAGE_WEBP:
|
|
imageData, imageErr = webp.Decode(bytes.NewReader(d))
|
|
case IMAGE_JPEG:
|
|
imageData, imageErr = jpeg.Decode(bytes.NewReader(d))
|
|
case IMAGE_PNG:
|
|
imageData, imageErr = png.Decode(bytes.NewReader(d))
|
|
case IMAGE_GIF:
|
|
imageData, imageErr = gif.Decode(bytes.NewReader(d))
|
|
default:
|
|
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
|
return
|
|
}
|
|
if imageErr != nil {
|
|
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
|
return
|
|
}
|
|
elapsedDecode = time.Since(t).Microseconds()
|
|
t = time.Now()
|
|
|
|
// Image Classification
|
|
results, err := ModelClassifyImage(imageData)
|
|
if err != nil {
|
|
http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
elapsedClassify = time.Since(t).Microseconds()
|
|
t = time.Now()
|
|
|
|
// Output Results
|
|
output, err := json.Marshal(map[string]any{
|
|
"allowed": results.Rating < MODEL_THRESHOLD,
|
|
"timings": map[string]any{
|
|
"probe": elapsedProbe,
|
|
"decode": elapsedDecode,
|
|
"classify": elapsedClassify,
|
|
"total": elapsedProbe + elapsedDecode + elapsedClassify,
|
|
},
|
|
"image": map[string]any{
|
|
"hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)),
|
|
"hash_md5": fmt.Sprintf("%x", md5.Sum(d)),
|
|
"height": imageInfo.Height,
|
|
"width": imageInfo.Width,
|
|
},
|
|
"logits": map[string]any{
|
|
"rating": math.Round(float64(results.Rating)*10000) / 100,
|
|
"drawing": math.Round(float64(results.Drawing)*10000) / 100,
|
|
"hentai": math.Round(float64(results.Hentai)*10000) / 100,
|
|
"neutral": math.Round(float64(results.Neutral)*10000) / 100,
|
|
"porn": math.Round(float64(results.Porn)*10000) / 100,
|
|
"sexy": math.Round(float64(results.Sexy)*10000) / 100,
|
|
},
|
|
})
|
|
if err != nil {
|
|
http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
w.Write(output)
|
|
fmt.Fprintf(os.Stdout, "%s\n", output)
|
|
|
|
// Other Methods
|
|
case http.MethodHead, http.MethodGet:
|
|
http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK)
|
|
|
|
default:
|
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
})
|
|
|
|
svr := http.Server{
|
|
Handler: mux,
|
|
Addr: HTTP_ADDRESS,
|
|
MaxHeaderBytes: 4096,
|
|
IdleTimeout: 10 * time.Second,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
ReadTimeout: 30 * time.Second,
|
|
}
|
|
|
|
// Shutdown Logic
|
|
await.Add(1)
|
|
go func() {
|
|
defer await.Done()
|
|
<-stop.Done()
|
|
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
if err := svr.Shutdown(shutdownCtx); err != nil {
|
|
log.Println("[HTTP] Shutdown Failed:", err)
|
|
}
|
|
|
|
log.Println("[HTTP] Server Closed")
|
|
}()
|
|
|
|
// Server Startup
|
|
log.Println("[HTTP] Listening @", HTTP_ADDRESS)
|
|
if err := svr.ListenAndServe(); err != http.ErrServerClosed {
|
|
log.Fatalln("[HTTP] Startup Failed:", err)
|
|
}
|
|
}
|
|
|
|
func SetupModel(stop context.Context, await *sync.WaitGroup) {
|
|
t := time.Now()
|
|
|
|
// Initialize Environment
|
|
onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH)
|
|
if err := onnx.InitializeEnvironment(); err != nil {
|
|
log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err)
|
|
return
|
|
}
|
|
|
|
// Initialize Settings
|
|
options, err := onnx.NewSessionOptions()
|
|
if err != nil {
|
|
log.Fatalf("[ONNX] Failed to create session options: %s\n", err)
|
|
return
|
|
}
|
|
defer options.Destroy()
|
|
|
|
if ONNX_RUNTIME_CUDA {
|
|
cudaOptions, err := onnx.NewCUDAProviderOptions()
|
|
if err != nil {
|
|
log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err)
|
|
} else {
|
|
defer cudaOptions.Destroy()
|
|
cudaOptions.Update(map[string]string{
|
|
"cudnn_conv_algo_search": "DEFAULT", // use the only working frontend
|
|
})
|
|
options.AppendExecutionProviderCUDA(cudaOptions)
|
|
}
|
|
}
|
|
|
|
// Initialize Model
|
|
session, err := onnx.NewDynamicAdvancedSessionWithONNXData(
|
|
ONNX_MODEL,
|
|
[]string{"input"},
|
|
[]string{"prediction"},
|
|
options,
|
|
)
|
|
if err != nil {
|
|
log.Printf("[ONNX] Failed to load model: %s\n", err)
|
|
return
|
|
}
|
|
ONNX_SESSION = session
|
|
|
|
// Test Model with Dummy Data
|
|
dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3)
|
|
if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil {
|
|
log.Printf("[ONNX] Failed to initialize model: %s\n", err)
|
|
return
|
|
}
|
|
|
|
await.Add(1)
|
|
go func() {
|
|
defer await.Done()
|
|
<-stop.Done()
|
|
ONNX_SESSION.Destroy()
|
|
ONNX_SESSION = nil
|
|
onnx.DestroyEnvironment()
|
|
log.Println("[ONNX] Closed")
|
|
}()
|
|
|
|
log.Printf("[ONNX] Model ready in %s\n", time.Since(t))
|
|
}
|
|
|
|
// Cast Predictions on a Tensor using the NSFW Model
|
|
func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) {
|
|
|
|
inputTensor, err := onnx.NewTensor(
|
|
onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3),
|
|
data,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer inputTensor.Destroy()
|
|
|
|
outputs := []onnx.ArbitraryTensor{nil}
|
|
if err := ONNX_SESSION.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil {
|
|
return nil, err
|
|
}
|
|
defer outputs[0].Destroy()
|
|
|
|
raw := outputs[0].(*onnx.Tensor[float32]).GetData()
|
|
results := make([]ClassifyResults, count)
|
|
|
|
for i := range results {
|
|
base := i * 5
|
|
results[i] = ClassifyResults{
|
|
Rating: (raw[base+1] + raw[base+3] + (raw[base+4] * 0.9)),
|
|
Drawing: raw[base+0],
|
|
Hentai: raw[base+1],
|
|
Neutral: raw[base+2],
|
|
Porn: raw[base+3],
|
|
Sexy: raw[base+4],
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// Classify an Image returning true if it's considered safe
|
|
func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) {
|
|
|
|
// Resize Image to Usable Size
|
|
resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE))
|
|
draw.CatmullRom.Scale(resized, resized.Rect, someImage, someImage.Bounds(), draw.Over, nil)
|
|
|
|
// Convert Pixel Data into Normalized Floats
|
|
var tensorCap = MODEL_SIZE * MODEL_SIZE * 3
|
|
var tensorData = make([]float32, 0, tensorCap)
|
|
for y := 0; y < MODEL_SIZE; y++ {
|
|
for x := 0; x < MODEL_SIZE; x++ {
|
|
r, g, b, _ := resized.At(x, y).RGBA()
|
|
tensorData = append(tensorData, float32(r)/65535, float32(g)/65535, float32(b)/65535)
|
|
}
|
|
}
|
|
|
|
// Create Tensor, reshape it, then classify
|
|
results, err := ModelClassifyTensorBatch(tensorData, 1)
|
|
if err != nil {
|
|
return ClassifyResults{}, err
|
|
}
|
|
|
|
// Return Results
|
|
// Drawing[0], Hentai[1], Neutral[2], Porn[3], Sexy[4]
|
|
return results[0], nil
|
|
}
|