commit 54e9ca5f34f737e8d5d6a7e9210669b2a55554fd Author: bakonpancakz Date: Sat May 23 17:22:03 2026 -0700 Initial Release diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..15ee2f1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 bakonpancakz + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README b/README new file mode 100644 index 0000000..1d45a6b --- /dev/null +++ b/README @@ -0,0 +1,63 @@ +---------------- + nsfw-service +---------------- + +[ Usage ] + +Send an HTTP POST request to port 9000 with the request body containing raw image data. + +Supported formats: WEBP, PNG, GIF, JPEG. + +Animated WEBPs and PNGs are not supported and will return a 415 status code. +For GIFs, only the first frame is used. + +--- Response Status Codes --- + +400 – Invalid image data +405 – Method not allowed (Use POST for inference and HEAD or GET for health) +411 – Failed to read request body +413 – Request body too large +415 – Unsupported image type +422 – Invalid image dimensions (>8192x4096px or <32px) +500 – Server error, response body is plaintext + +--- Example Response Body --- + +{ + "allowed": true, + "image": { + "hash_md5": "50f64bdb0f11d281505bce990e805569", + "hash_sha256": "f2a94429ccd5e5467f6a1f2bd166d8def75ced242a59b6f71b659e827c008b75", + "height": 850, + "width": 850 + }, + "logits": { + "drawing": 0.85562634, + "hentai": 0.14418653, + "neutral": 0.000010294689, + "porn": 0.00014420893, + "sexy": 0.00003254598 + }, + "timings": { + "classify": 7125, + "decode": 8746, + "probe": 0, + "total": 15871 + } +} + +[ Config ] + +Configure the service using environment variables: + +| Name | Default | Description +| HTTP_ADDRESS | 127.0.0.1:9000 | Address and Port to use for requests +| ONNX_RUNTIME_PATH | | Set to anything to enable CUDA, requires you supply your own ONNX runtime +| ONNX_RUNTIME_PATH | | Path to the ONNX runtime library, either a onnxruntime.dll or libonnxruntime.so +| HTTP_CONCURRENCY | 16 | Maximum concurrent api requests +| HTTP_MAX_BODY_BYTES | 16777216 | Maximum request size in bytes, default is 16MB +| MODEL_THRESHOLD | 0.7 | Threshold before an image is consider inappropriate + +[ Credits ] + +Uses the following AI model: https://github.com/GantMan/nsfw_model \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f9efe97 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module service-nsfw + +go 1.26 + +require ( + github.com/yalue/onnxruntime_go v1.28.0 + golang.org/x/image v0.39.0 + golang.org/x/sync v0.20.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9a4d61d --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/yalue/onnxruntime_go v1.28.0 h1:ximEqgLtBhb3DY0IHyR0GWGpGJ+xef85qxgPwa/iotg= +github.com/yalue/onnxruntime_go v1.28.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww= +golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= diff --git a/main.go b/main.go new file mode 100644 index 0000000..80d559b --- /dev/null +++ b/main.go @@ -0,0 +1,440 @@ +package main + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/sha256" + _ "embed" + "encoding/json" + "fmt" + "image" + "image/gif" + "image/jpeg" + "image/png" + "io" + "log" + "math" + "net/http" + "os" + "os/signal" + "strconv" + "sync" + "syscall" + "time" + + onnx "github.com/yalue/onnxruntime_go" + + "golang.org/x/image/draw" + "golang.org/x/image/webp" + "golang.org/x/sync/semaphore" +) + +type ImageType string + +const ( + IMAGE_OTHER ImageType = "UNKNOWN" + IMAGE_WEBP ImageType = "WEBP" + IMAGE_JPEG ImageType = "JPG" + IMAGE_PNG ImageType = "PNG" + IMAGE_GIF ImageType = "GIF" + MODEL_SIZE = 224 +) + +var ( + //go:embed nsfw.onnx + ONNX_MODEL []byte + ONNX_SESSION *onnx.DynamicAdvancedSession + ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH") + ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != "" + HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS") + HTTP_SEMAPHORE = semaphore.NewWeighted(16) + HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb + MODEL_THRESHOLD float32 = 0.7 +) + +func init() { + if HTTP_ADDRESS == "" { + HTTP_ADDRESS = "127.0.0.1:9000" + } + + if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" { + val, err := strconv.ParseInt(plaintext, 10, 64) + if err == nil { + HTTP_SEMAPHORE = semaphore.NewWeighted(val) + } + } + + if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" { + val, err := strconv.ParseInt(plaintext, 10, 64) + if err == nil { + HTTP_MAX_BODY_BYTES = val + } + } + + if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" { + val, err := strconv.ParseFloat(plaintext, 32) + if err == nil { + MODEL_THRESHOLD = float32(val) + } + } +} + +type ClassifyResults struct { + Rating float32 + Drawing float32 + Hentai float32 + Neutral float32 + Porn float32 + Sexy float32 +} + +func main() { + time.Local = time.UTC + + // Startup Services + var stopCtx, stop = context.WithCancel(context.Background()) + var stopWg sync.WaitGroup + + SetupModel(stopCtx, &stopWg) + go SetupHTTP(stopCtx, &stopWg) + + // Await Shutdown Signal + cancel := make(chan os.Signal, 1) + signal.Notify(cancel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + <-cancel + stop() + + // Begin Shutdown Process + timeout, finish := context.WithTimeout(context.Background(), time.Minute) + defer finish() + go func() { + <-timeout.Done() + if timeout.Err() == context.DeadlineExceeded { + log.Fatalln("[MAIN] Shutdown Deadline Exceeded") + } + }() + stopWg.Wait() + os.Exit(0) +} + +func SetupHTTP(stop context.Context, await *sync.WaitGroup) { + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + var elapsedProbe, elapsedDecode, elapsedClassify int64 + t := time.Now() + + switch r.Method { + + case http.MethodPost: + // Request Validation + if r.ContentLength > HTTP_MAX_BODY_BYTES { + http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge) + return + } + r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES) + + var d []byte + if raw, err := io.ReadAll(r.Body); err != nil { + http.Error(w, "Body Error", http.StatusLengthRequired) + return + } else { + d = raw + } + + // Image Validation + var imageType ImageType + var imageInfo image.Config + var imageData image.Image + var imageErr error + + switch { // Sniffing + case len(d) > 3 && // JPEG + d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF: + imageType = IMAGE_JPEG + + case len(d) > 8 && // PNG + d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 && + d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A: + imageType = IMAGE_PNG + + case len(d) > 4 && // GIF + d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38: + imageType = IMAGE_GIF + + case len(d) > 12 && // WEBP + d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 && + d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50: + imageType = IMAGE_WEBP + + default: + imageType = IMAGE_OTHER + } + + switch imageType { // Decode Header + case IMAGE_WEBP: + imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d)) + case IMAGE_JPEG: + imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d)) + case IMAGE_PNG: + imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d)) + case IMAGE_GIF: + imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d)) + default: + http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType) + return + } + elapsedProbe = time.Since(t).Microseconds() + t = time.Now() + + if imageErr != nil { + http.Error(w, "Invalid Image Data", http.StatusBadRequest) + return + } + + if imageInfo.Width > 8192 { + http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity) + return + } + if imageInfo.Height > 4096 { + http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity) + return + } + if imageInfo.Height < 32 || imageInfo.Width < 32 { + http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity) + return + } + + switch imageType { // Decode Pixels + case IMAGE_WEBP: + imageData, imageErr = webp.Decode(bytes.NewReader(d)) + case IMAGE_JPEG: + imageData, imageErr = jpeg.Decode(bytes.NewReader(d)) + case IMAGE_PNG: + imageData, imageErr = png.Decode(bytes.NewReader(d)) + case IMAGE_GIF: + imageData, imageErr = gif.Decode(bytes.NewReader(d)) + default: + http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType) + return + } + if imageErr != nil { + http.Error(w, "Invalid Image Data", http.StatusBadRequest) + return + } + elapsedDecode = time.Since(t).Microseconds() + t = time.Now() + + // Image Classification + results, err := ModelClassifyImage(imageData) + if err != nil { + http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError) + return + } + elapsedClassify = time.Since(t).Microseconds() + t = time.Now() + + // Output Results + output, err := json.Marshal(map[string]any{ + "allowed": results.Rating < MODEL_THRESHOLD, + "timings": map[string]any{ + "probe": elapsedProbe, + "decode": elapsedDecode, + "classify": elapsedClassify, + "total": elapsedProbe + elapsedDecode + elapsedClassify, + }, + "image": map[string]any{ + "hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)), + "hash_md5": fmt.Sprintf("%x", md5.Sum(d)), + "height": imageInfo.Height, + "width": imageInfo.Width, + }, + "logits": map[string]any{ + "rating": math.Round(float64(results.Rating)*10000) / 100, + "drawing": math.Round(float64(results.Drawing)*10000) / 100, + "hentai": math.Round(float64(results.Hentai)*10000) / 100, + "neutral": math.Round(float64(results.Neutral)*10000) / 100, + "porn": math.Round(float64(results.Porn)*10000) / 100, + "sexy": math.Round(float64(results.Sexy)*10000) / 100, + }, + }) + if err != nil { + http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(output) + fmt.Fprintf(os.Stdout, "%s\n", output) + + // Other Methods + case http.MethodHead, http.MethodGet: + http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK) + + default: + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + } + + }) + + svr := http.Server{ + Handler: mux, + Addr: HTTP_ADDRESS, + MaxHeaderBytes: 4096, + IdleTimeout: 10 * time.Second, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + ReadTimeout: 30 * time.Second, + } + + // Shutdown Logic + await.Add(1) + go func() { + defer await.Done() + <-stop.Done() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := svr.Shutdown(shutdownCtx); err != nil { + log.Println("[HTTP] Shutdown Failed:", err) + } + + log.Println("[HTTP] Server Closed") + }() + + // Server Startup + log.Println("[HTTP] Listening @", HTTP_ADDRESS) + if err := svr.ListenAndServe(); err != http.ErrServerClosed { + log.Fatalln("[HTTP] Startup Failed:", err) + } +} + +func SetupModel(stop context.Context, await *sync.WaitGroup) { + t := time.Now() + + // Initialize Environment + onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH) + if err := onnx.InitializeEnvironment(); err != nil { + log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err) + return + } + + // Initialize Settings + options, err := onnx.NewSessionOptions() + if err != nil { + log.Fatalf("[ONNX] Failed to create session options: %s\n", err) + return + } + defer options.Destroy() + + if ONNX_RUNTIME_CUDA { + cudaOptions, err := onnx.NewCUDAProviderOptions() + if err != nil { + log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err) + } else { + defer cudaOptions.Destroy() + cudaOptions.Update(map[string]string{ + "cudnn_conv_algo_search": "DEFAULT", // use the only working frontend + }) + options.AppendExecutionProviderCUDA(cudaOptions) + } + } + + // Initialize Model + session, err := onnx.NewDynamicAdvancedSessionWithONNXData( + ONNX_MODEL, + []string{"input"}, + []string{"prediction"}, + options, + ) + if err != nil { + log.Printf("[ONNX] Failed to load model: %s\n", err) + return + } + ONNX_SESSION = session + + // Test Model with Dummy Data + dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3) + if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil { + log.Printf("[ONNX] Failed to initialize model: %s\n", err) + return + } + + await.Add(1) + go func() { + defer await.Done() + <-stop.Done() + ONNX_SESSION.Destroy() + ONNX_SESSION = nil + onnx.DestroyEnvironment() + log.Println("[ONNX] Closed") + }() + + log.Printf("[ONNX] Model ready in %s\n", time.Since(t)) +} + +// Cast Predictions on a Tensor using the NSFW Model +func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) { + + inputTensor, err := onnx.NewTensor( + onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3), + data, + ) + if err != nil { + return nil, err + } + defer inputTensor.Destroy() + + outputs := []onnx.ArbitraryTensor{nil} + if err := ONNX_SESSION.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil { + return nil, err + } + defer outputs[0].Destroy() + + raw := outputs[0].(*onnx.Tensor[float32]).GetData() + results := make([]ClassifyResults, count) + + for i := range results { + base := i * 5 + results[i] = ClassifyResults{ + Rating: (raw[base+1] + raw[base+3] + (raw[base+4] * 0.9)), + Drawing: raw[base+0], + Hentai: raw[base+1], + Neutral: raw[base+2], + Porn: raw[base+3], + Sexy: raw[base+4], + } + } + + return results, nil +} + +// Classify an Image returning true if it's considered safe +func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) { + + // Resize Image to Usable Size + resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE)) + draw.CatmullRom.Scale(resized, resized.Rect, someImage, someImage.Bounds(), draw.Over, nil) + + // Convert Pixel Data into Normalized Floats + var tensorCap = MODEL_SIZE * MODEL_SIZE * 3 + var tensorData = make([]float32, 0, tensorCap) + for y := 0; y < MODEL_SIZE; y++ { + for x := 0; x < MODEL_SIZE; x++ { + r, g, b, _ := resized.At(x, y).RGBA() + tensorData = append(tensorData, float32(r)/65535, float32(g)/65535, float32(b)/65535) + } + } + + // Create Tensor, reshape it, then classify + results, err := ModelClassifyTensorBatch(tensorData, 1) + if err != nil { + return ClassifyResults{}, err + } + + // Return Results + // Drawing[0], Hentai[1], Neutral[2], Porn[3], Sexy[4] + return results[0], nil +} diff --git a/nsfw.onnx b/nsfw.onnx new file mode 100644 index 0000000..50597a8 Binary files /dev/null and b/nsfw.onnx differ