package main import ( "bytes" "context" "crypto/md5" "crypto/sha256" _ "embed" "encoding/json" "fmt" "image" "image/gif" "image/jpeg" "image/png" "io" "log" "math" "net/http" "os" "os/signal" "strconv" "sync" "syscall" "time" onnx "github.com/yalue/onnxruntime_go" "golang.org/x/image/draw" "golang.org/x/image/webp" "golang.org/x/sync/semaphore" ) type ImageType string const ( IMAGE_OTHER ImageType = "UNKNOWN" IMAGE_WEBP ImageType = "WEBP" IMAGE_JPEG ImageType = "JPG" IMAGE_PNG ImageType = "PNG" IMAGE_GIF ImageType = "GIF" MODEL_SIZE = 224 ) var ( //go:embed nsfw.onnx ONNX_MODEL []byte ONNX_SESSION *onnx.DynamicAdvancedSession ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH") ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != "" HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS") HTTP_SEMAPHORE = semaphore.NewWeighted(16) HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb MODEL_THRESHOLD float32 = 0.7 ) func init() { if HTTP_ADDRESS == "" { HTTP_ADDRESS = "127.0.0.1:9000" } if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" { val, err := strconv.ParseInt(plaintext, 10, 64) if err == nil { HTTP_SEMAPHORE = semaphore.NewWeighted(val) } } if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" { val, err := strconv.ParseInt(plaintext, 10, 64) if err == nil { HTTP_MAX_BODY_BYTES = val } } if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" { val, err := strconv.ParseFloat(plaintext, 32) if err == nil { MODEL_THRESHOLD = float32(val) } } } type ClassifyResults struct { Rating float32 Drawing float32 Hentai float32 Neutral float32 Porn float32 Sexy float32 } func main() { time.Local = time.UTC // Startup Services var stopCtx, stop = context.WithCancel(context.Background()) var stopWg sync.WaitGroup SetupModel(stopCtx, &stopWg) go SetupHTTP(stopCtx, &stopWg) // Await Shutdown Signal cancel := make(chan os.Signal, 1) signal.Notify(cancel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) <-cancel stop() // Begin Shutdown Process timeout, finish := context.WithTimeout(context.Background(), time.Minute) defer finish() go func() { <-timeout.Done() if timeout.Err() == context.DeadlineExceeded { log.Fatalln("[MAIN] Shutdown Deadline Exceeded") } }() stopWg.Wait() os.Exit(0) } func SetupHTTP(stop context.Context, await *sync.WaitGroup) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { var elapsedProbe, elapsedDecode, elapsedClassify int64 t := time.Now() switch r.Method { case http.MethodPost: // Request Validation if r.ContentLength > HTTP_MAX_BODY_BYTES { http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge) return } r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES) var d []byte if raw, err := io.ReadAll(r.Body); err != nil { http.Error(w, "Body Error", http.StatusLengthRequired) return } else { d = raw } // Image Validation var imageType ImageType var imageInfo image.Config var imageData image.Image var imageErr error switch { // Sniffing case len(d) > 3 && // JPEG d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF: imageType = IMAGE_JPEG case len(d) > 8 && // PNG d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 && d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A: imageType = IMAGE_PNG case len(d) > 4 && // GIF d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38: imageType = IMAGE_GIF case len(d) > 12 && // WEBP d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 && d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50: imageType = IMAGE_WEBP default: imageType = IMAGE_OTHER } switch imageType { // Decode Header case IMAGE_WEBP: imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d)) case IMAGE_JPEG: imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d)) case IMAGE_PNG: imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d)) case IMAGE_GIF: imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d)) default: http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType) return } elapsedProbe = time.Since(t).Microseconds() t = time.Now() if imageErr != nil { http.Error(w, "Invalid Image Data", http.StatusBadRequest) return } if imageInfo.Width > 8192 { http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity) return } if imageInfo.Height > 4096 { http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity) return } if imageInfo.Height < 32 || imageInfo.Width < 32 { http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity) return } switch imageType { // Decode Pixels case IMAGE_WEBP: imageData, imageErr = webp.Decode(bytes.NewReader(d)) case IMAGE_JPEG: imageData, imageErr = jpeg.Decode(bytes.NewReader(d)) case IMAGE_PNG: imageData, imageErr = png.Decode(bytes.NewReader(d)) case IMAGE_GIF: imageData, imageErr = gif.Decode(bytes.NewReader(d)) default: http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType) return } if imageErr != nil { http.Error(w, "Invalid Image Data", http.StatusBadRequest) return } elapsedDecode = time.Since(t).Microseconds() t = time.Now() // Image Classification results, err := ModelClassifyImage(imageData) if err != nil { http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError) return } elapsedClassify = time.Since(t).Microseconds() t = time.Now() // Output Results output, err := json.Marshal(map[string]any{ "allowed": results.Rating < MODEL_THRESHOLD, "timings": map[string]any{ "probe": elapsedProbe, "decode": elapsedDecode, "classify": elapsedClassify, "total": elapsedProbe + elapsedDecode + elapsedClassify, }, "image": map[string]any{ "hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)), "hash_md5": fmt.Sprintf("%x", md5.Sum(d)), "height": imageInfo.Height, "width": imageInfo.Width, }, "logits": map[string]any{ "rating": math.Round(float64(results.Rating)*10000) / 100, "drawing": math.Round(float64(results.Drawing)*10000) / 100, "hentai": math.Round(float64(results.Hentai)*10000) / 100, "neutral": math.Round(float64(results.Neutral)*10000) / 100, "porn": math.Round(float64(results.Porn)*10000) / 100, "sexy": math.Round(float64(results.Sexy)*10000) / 100, }, }) if err != nil { http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write(output) fmt.Fprintf(os.Stdout, "%s\n", output) // Other Methods case http.MethodHead, http.MethodGet: http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK) default: http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) } }) svr := http.Server{ Handler: mux, Addr: HTTP_ADDRESS, MaxHeaderBytes: 4096, IdleTimeout: 10 * time.Second, ReadHeaderTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second, } // Shutdown Logic await.Add(1) go func() { defer await.Done() <-stop.Done() shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if err := svr.Shutdown(shutdownCtx); err != nil { log.Println("[HTTP] Shutdown Failed:", err) } log.Println("[HTTP] Server Closed") }() // Server Startup log.Println("[HTTP] Listening @", HTTP_ADDRESS) if err := svr.ListenAndServe(); err != http.ErrServerClosed { log.Fatalln("[HTTP] Startup Failed:", err) } } func SetupModel(stop context.Context, await *sync.WaitGroup) { t := time.Now() // Initialize Environment onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH) if err := onnx.InitializeEnvironment(); err != nil { log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err) return } // Initialize Settings options, err := onnx.NewSessionOptions() if err != nil { log.Fatalf("[ONNX] Failed to create session options: %s\n", err) return } defer options.Destroy() if ONNX_RUNTIME_CUDA { cudaOptions, err := onnx.NewCUDAProviderOptions() if err != nil { log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err) } else { defer cudaOptions.Destroy() cudaOptions.Update(map[string]string{ "cudnn_conv_algo_search": "DEFAULT", // use the only working frontend }) options.AppendExecutionProviderCUDA(cudaOptions) } } // Initialize Model session, err := onnx.NewDynamicAdvancedSessionWithONNXData( ONNX_MODEL, []string{"input"}, []string{"prediction"}, options, ) if err != nil { log.Printf("[ONNX] Failed to load model: %s\n", err) return } ONNX_SESSION = session // Test Model with Dummy Data dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3) if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil { log.Printf("[ONNX] Failed to initialize model: %s\n", err) return } await.Add(1) go func() { defer await.Done() <-stop.Done() ONNX_SESSION.Destroy() ONNX_SESSION = nil onnx.DestroyEnvironment() log.Println("[ONNX] Closed") }() log.Printf("[ONNX] Model ready in %s\n", time.Since(t)) } // Cast Predictions on a Tensor using the NSFW Model func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) { inputTensor, err := onnx.NewTensor( onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3), data, ) if err != nil { return nil, err } defer inputTensor.Destroy() outputs := []onnx.ArbitraryTensor{nil} if err := ONNX_SESSION.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil { return nil, err } defer outputs[0].Destroy() raw := outputs[0].(*onnx.Tensor[float32]).GetData() results := make([]ClassifyResults, count) for i := range results { base := i * 5 results[i] = ClassifyResults{ Rating: (raw[base+1] + raw[base+3] + (raw[base+4] * 0.9)), Drawing: raw[base+0], Hentai: raw[base+1], Neutral: raw[base+2], Porn: raw[base+3], Sexy: raw[base+4], } } return results, nil } // Classify an Image returning true if it's considered safe func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) { // Resize Image to Usable Size resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE)) draw.CatmullRom.Scale(resized, resized.Rect, someImage, someImage.Bounds(), draw.Over, nil) // Convert Pixel Data into Normalized Floats var tensorCap = MODEL_SIZE * MODEL_SIZE * 3 var tensorData = make([]float32, 0, tensorCap) for y := 0; y < MODEL_SIZE; y++ { for x := 0; x < MODEL_SIZE; x++ { r, g, b, _ := resized.At(x, y).RGBA() tensorData = append(tensorData, float32(r)/65535, float32(g)/65535, float32(b)/65535) } } // Create Tensor, reshape it, then classify results, err := ModelClassifyTensorBatch(tensorData, 1) if err != nil { return ClassifyResults{}, err } // Return Results // Drawing[0], Hentai[1], Neutral[2], Porn[3], Sexy[4] return results[0], nil }