Initial Release
This commit is contained in:
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 bakonpancakz
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
----------------
|
||||||
|
nsfw-service
|
||||||
|
----------------
|
||||||
|
|
||||||
|
[ Usage ]
|
||||||
|
|
||||||
|
Send an HTTP POST request to port 9000 with the request body containing raw image data.
|
||||||
|
|
||||||
|
Supported formats: WEBP, PNG, GIF, JPEG.
|
||||||
|
|
||||||
|
Animated WEBPs and PNGs are not supported and will return a 415 status code.
|
||||||
|
For GIFs, only the first frame is used.
|
||||||
|
|
||||||
|
--- Response Status Codes ---
|
||||||
|
|
||||||
|
400 – Invalid image data
|
||||||
|
405 – Method not allowed (Use POST for inference and HEAD or GET for health)
|
||||||
|
411 – Failed to read request body
|
||||||
|
413 – Request body too large
|
||||||
|
415 – Unsupported image type
|
||||||
|
422 – Invalid image dimensions (>8192x4096px or <32px)
|
||||||
|
500 – Server error, response body is plaintext
|
||||||
|
|
||||||
|
--- Example Response Body ---
|
||||||
|
|
||||||
|
{
|
||||||
|
"allowed": true,
|
||||||
|
"image": {
|
||||||
|
"hash_md5": "50f64bdb0f11d281505bce990e805569",
|
||||||
|
"hash_sha256": "f2a94429ccd5e5467f6a1f2bd166d8def75ced242a59b6f71b659e827c008b75",
|
||||||
|
"height": 850,
|
||||||
|
"width": 850
|
||||||
|
},
|
||||||
|
"logits": {
|
||||||
|
"drawing": 0.85562634,
|
||||||
|
"hentai": 0.14418653,
|
||||||
|
"neutral": 0.000010294689,
|
||||||
|
"porn": 0.00014420893,
|
||||||
|
"sexy": 0.00003254598
|
||||||
|
},
|
||||||
|
"timings": {
|
||||||
|
"classify": 7125,
|
||||||
|
"decode": 8746,
|
||||||
|
"probe": 0,
|
||||||
|
"total": 15871
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[ Config ]
|
||||||
|
|
||||||
|
Configure the service using environment variables:
|
||||||
|
|
||||||
|
| Name | Default | Description
|
||||||
|
| HTTP_ADDRESS | 127.0.0.1:9000 | Address and Port to use for requests
|
||||||
|
| ONNX_RUNTIME_PATH | <empty> | Set to anything to enable CUDA, requires you supply your own ONNX runtime
|
||||||
|
| ONNX_RUNTIME_PATH | <empty> | Path to the ONNX runtime library, either a onnxruntime.dll or libonnxruntime.so
|
||||||
|
| HTTP_CONCURRENCY | 16 | Maximum concurrent api requests
|
||||||
|
| HTTP_MAX_BODY_BYTES | 16777216 | Maximum request size in bytes, default is 16MB
|
||||||
|
| MODEL_THRESHOLD | 0.7 | Threshold before an image is consider inappropriate
|
||||||
|
|
||||||
|
[ Credits ]
|
||||||
|
|
||||||
|
Uses the following AI model: https://github.com/GantMan/nsfw_model
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
module service-nsfw
|
||||||
|
|
||||||
|
go 1.26
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/yalue/onnxruntime_go v1.28.0
|
||||||
|
golang.org/x/image v0.39.0
|
||||||
|
golang.org/x/sync v0.20.0
|
||||||
|
)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
github.com/yalue/onnxruntime_go v1.28.0 h1:ximEqgLtBhb3DY0IHyR0GWGpGJ+xef85qxgPwa/iotg=
|
||||||
|
github.com/yalue/onnxruntime_go v1.28.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||||
|
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
|
||||||
|
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
@@ -0,0 +1,440 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha256"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/gif"
|
||||||
|
"image/jpeg"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
onnx "github.com/yalue/onnxruntime_go"
|
||||||
|
|
||||||
|
"golang.org/x/image/draw"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
IMAGE_OTHER ImageType = "UNKNOWN"
|
||||||
|
IMAGE_WEBP ImageType = "WEBP"
|
||||||
|
IMAGE_JPEG ImageType = "JPG"
|
||||||
|
IMAGE_PNG ImageType = "PNG"
|
||||||
|
IMAGE_GIF ImageType = "GIF"
|
||||||
|
MODEL_SIZE = 224
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed nsfw.onnx
|
||||||
|
ONNX_MODEL []byte
|
||||||
|
ONNX_SESSION *onnx.DynamicAdvancedSession
|
||||||
|
ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH")
|
||||||
|
ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != ""
|
||||||
|
HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS")
|
||||||
|
HTTP_SEMAPHORE = semaphore.NewWeighted(16)
|
||||||
|
HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb
|
||||||
|
MODEL_THRESHOLD float32 = 0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if HTTP_ADDRESS == "" {
|
||||||
|
HTTP_ADDRESS = "127.0.0.1:9000"
|
||||||
|
}
|
||||||
|
|
||||||
|
if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" {
|
||||||
|
val, err := strconv.ParseInt(plaintext, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
HTTP_SEMAPHORE = semaphore.NewWeighted(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" {
|
||||||
|
val, err := strconv.ParseInt(plaintext, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
HTTP_MAX_BODY_BYTES = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" {
|
||||||
|
val, err := strconv.ParseFloat(plaintext, 32)
|
||||||
|
if err == nil {
|
||||||
|
MODEL_THRESHOLD = float32(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassifyResults struct {
|
||||||
|
Rating float32
|
||||||
|
Drawing float32
|
||||||
|
Hentai float32
|
||||||
|
Neutral float32
|
||||||
|
Porn float32
|
||||||
|
Sexy float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
time.Local = time.UTC
|
||||||
|
|
||||||
|
// Startup Services
|
||||||
|
var stopCtx, stop = context.WithCancel(context.Background())
|
||||||
|
var stopWg sync.WaitGroup
|
||||||
|
|
||||||
|
SetupModel(stopCtx, &stopWg)
|
||||||
|
go SetupHTTP(stopCtx, &stopWg)
|
||||||
|
|
||||||
|
// Await Shutdown Signal
|
||||||
|
cancel := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(cancel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-cancel
|
||||||
|
stop()
|
||||||
|
|
||||||
|
// Begin Shutdown Process
|
||||||
|
timeout, finish := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer finish()
|
||||||
|
go func() {
|
||||||
|
<-timeout.Done()
|
||||||
|
if timeout.Err() == context.DeadlineExceeded {
|
||||||
|
log.Fatalln("[MAIN] Shutdown Deadline Exceeded")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
stopWg.Wait()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupHTTP(stop context.Context, await *sync.WaitGroup) {
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var elapsedProbe, elapsedDecode, elapsedClassify int64
|
||||||
|
t := time.Now()
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
|
||||||
|
case http.MethodPost:
|
||||||
|
// Request Validation
|
||||||
|
if r.ContentLength > HTTP_MAX_BODY_BYTES {
|
||||||
|
http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES)
|
||||||
|
|
||||||
|
var d []byte
|
||||||
|
if raw, err := io.ReadAll(r.Body); err != nil {
|
||||||
|
http.Error(w, "Body Error", http.StatusLengthRequired)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
d = raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Image Validation
|
||||||
|
var imageType ImageType
|
||||||
|
var imageInfo image.Config
|
||||||
|
var imageData image.Image
|
||||||
|
var imageErr error
|
||||||
|
|
||||||
|
switch { // Sniffing
|
||||||
|
case len(d) > 3 && // JPEG
|
||||||
|
d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF:
|
||||||
|
imageType = IMAGE_JPEG
|
||||||
|
|
||||||
|
case len(d) > 8 && // PNG
|
||||||
|
d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 &&
|
||||||
|
d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A:
|
||||||
|
imageType = IMAGE_PNG
|
||||||
|
|
||||||
|
case len(d) > 4 && // GIF
|
||||||
|
d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38:
|
||||||
|
imageType = IMAGE_GIF
|
||||||
|
|
||||||
|
case len(d) > 12 && // WEBP
|
||||||
|
d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 &&
|
||||||
|
d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50:
|
||||||
|
imageType = IMAGE_WEBP
|
||||||
|
|
||||||
|
default:
|
||||||
|
imageType = IMAGE_OTHER
|
||||||
|
}
|
||||||
|
|
||||||
|
switch imageType { // Decode Header
|
||||||
|
case IMAGE_WEBP:
|
||||||
|
imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d))
|
||||||
|
case IMAGE_JPEG:
|
||||||
|
imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d))
|
||||||
|
case IMAGE_PNG:
|
||||||
|
imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d))
|
||||||
|
case IMAGE_GIF:
|
||||||
|
imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d))
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elapsedProbe = time.Since(t).Microseconds()
|
||||||
|
t = time.Now()
|
||||||
|
|
||||||
|
if imageErr != nil {
|
||||||
|
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageInfo.Width > 8192 {
|
||||||
|
http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if imageInfo.Height > 4096 {
|
||||||
|
http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if imageInfo.Height < 32 || imageInfo.Width < 32 {
|
||||||
|
http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch imageType { // Decode Pixels
|
||||||
|
case IMAGE_WEBP:
|
||||||
|
imageData, imageErr = webp.Decode(bytes.NewReader(d))
|
||||||
|
case IMAGE_JPEG:
|
||||||
|
imageData, imageErr = jpeg.Decode(bytes.NewReader(d))
|
||||||
|
case IMAGE_PNG:
|
||||||
|
imageData, imageErr = png.Decode(bytes.NewReader(d))
|
||||||
|
case IMAGE_GIF:
|
||||||
|
imageData, imageErr = gif.Decode(bytes.NewReader(d))
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if imageErr != nil {
|
||||||
|
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elapsedDecode = time.Since(t).Microseconds()
|
||||||
|
t = time.Now()
|
||||||
|
|
||||||
|
// Image Classification
|
||||||
|
results, err := ModelClassifyImage(imageData)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elapsedClassify = time.Since(t).Microseconds()
|
||||||
|
t = time.Now()
|
||||||
|
|
||||||
|
// Output Results
|
||||||
|
output, err := json.Marshal(map[string]any{
|
||||||
|
"allowed": results.Rating < MODEL_THRESHOLD,
|
||||||
|
"timings": map[string]any{
|
||||||
|
"probe": elapsedProbe,
|
||||||
|
"decode": elapsedDecode,
|
||||||
|
"classify": elapsedClassify,
|
||||||
|
"total": elapsedProbe + elapsedDecode + elapsedClassify,
|
||||||
|
},
|
||||||
|
"image": map[string]any{
|
||||||
|
"hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)),
|
||||||
|
"hash_md5": fmt.Sprintf("%x", md5.Sum(d)),
|
||||||
|
"height": imageInfo.Height,
|
||||||
|
"width": imageInfo.Width,
|
||||||
|
},
|
||||||
|
"logits": map[string]any{
|
||||||
|
"rating": math.Round(float64(results.Rating)*10000) / 100,
|
||||||
|
"drawing": math.Round(float64(results.Drawing)*10000) / 100,
|
||||||
|
"hentai": math.Round(float64(results.Hentai)*10000) / 100,
|
||||||
|
"neutral": math.Round(float64(results.Neutral)*10000) / 100,
|
||||||
|
"porn": math.Round(float64(results.Porn)*10000) / 100,
|
||||||
|
"sexy": math.Round(float64(results.Sexy)*10000) / 100,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Write(output)
|
||||||
|
fmt.Fprintf(os.Stdout, "%s\n", output)
|
||||||
|
|
||||||
|
// Other Methods
|
||||||
|
case http.MethodHead, http.MethodGet:
|
||||||
|
http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK)
|
||||||
|
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
svr := http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
Addr: HTTP_ADDRESS,
|
||||||
|
MaxHeaderBytes: 4096,
|
||||||
|
IdleTimeout: 10 * time.Second,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown Logic
|
||||||
|
await.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer await.Done()
|
||||||
|
<-stop.Done()
|
||||||
|
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := svr.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Println("[HTTP] Shutdown Failed:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("[HTTP] Server Closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Server Startup
|
||||||
|
log.Println("[HTTP] Listening @", HTTP_ADDRESS)
|
||||||
|
if err := svr.ListenAndServe(); err != http.ErrServerClosed {
|
||||||
|
log.Fatalln("[HTTP] Startup Failed:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupModel(stop context.Context, await *sync.WaitGroup) {
|
||||||
|
t := time.Now()
|
||||||
|
|
||||||
|
// Initialize Environment
|
||||||
|
onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH)
|
||||||
|
if err := onnx.InitializeEnvironment(); err != nil {
|
||||||
|
log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Settings
|
||||||
|
options, err := onnx.NewSessionOptions()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("[ONNX] Failed to create session options: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer options.Destroy()
|
||||||
|
|
||||||
|
if ONNX_RUNTIME_CUDA {
|
||||||
|
cudaOptions, err := onnx.NewCUDAProviderOptions()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err)
|
||||||
|
} else {
|
||||||
|
defer cudaOptions.Destroy()
|
||||||
|
cudaOptions.Update(map[string]string{
|
||||||
|
"cudnn_conv_algo_search": "DEFAULT", // use the only working frontend
|
||||||
|
})
|
||||||
|
options.AppendExecutionProviderCUDA(cudaOptions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Model
|
||||||
|
session, err := onnx.NewDynamicAdvancedSessionWithONNXData(
|
||||||
|
ONNX_MODEL,
|
||||||
|
[]string{"input"},
|
||||||
|
[]string{"prediction"},
|
||||||
|
options,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ONNX] Failed to load model: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ONNX_SESSION = session
|
||||||
|
|
||||||
|
// Test Model with Dummy Data
|
||||||
|
dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3)
|
||||||
|
if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil {
|
||||||
|
log.Printf("[ONNX] Failed to initialize model: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
await.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer await.Done()
|
||||||
|
<-stop.Done()
|
||||||
|
ONNX_SESSION.Destroy()
|
||||||
|
ONNX_SESSION = nil
|
||||||
|
onnx.DestroyEnvironment()
|
||||||
|
log.Println("[ONNX] Closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("[ONNX] Model ready in %s\n", time.Since(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cast Predictions on a Tensor using the NSFW Model
|
||||||
|
func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) {
|
||||||
|
|
||||||
|
inputTensor, err := onnx.NewTensor(
|
||||||
|
onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3),
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer inputTensor.Destroy()
|
||||||
|
|
||||||
|
outputs := []onnx.ArbitraryTensor{nil}
|
||||||
|
if err := ONNX_SESSION.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer outputs[0].Destroy()
|
||||||
|
|
||||||
|
raw := outputs[0].(*onnx.Tensor[float32]).GetData()
|
||||||
|
results := make([]ClassifyResults, count)
|
||||||
|
|
||||||
|
for i := range results {
|
||||||
|
base := i * 5
|
||||||
|
results[i] = ClassifyResults{
|
||||||
|
Rating: (raw[base+1] + raw[base+3] + (raw[base+4] * 0.9)),
|
||||||
|
Drawing: raw[base+0],
|
||||||
|
Hentai: raw[base+1],
|
||||||
|
Neutral: raw[base+2],
|
||||||
|
Porn: raw[base+3],
|
||||||
|
Sexy: raw[base+4],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classify an Image returning true if it's considered safe
|
||||||
|
func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) {
|
||||||
|
|
||||||
|
// Resize Image to Usable Size
|
||||||
|
resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE))
|
||||||
|
draw.CatmullRom.Scale(resized, resized.Rect, someImage, someImage.Bounds(), draw.Over, nil)
|
||||||
|
|
||||||
|
// Convert Pixel Data into Normalized Floats
|
||||||
|
var tensorCap = MODEL_SIZE * MODEL_SIZE * 3
|
||||||
|
var tensorData = make([]float32, 0, tensorCap)
|
||||||
|
for y := 0; y < MODEL_SIZE; y++ {
|
||||||
|
for x := 0; x < MODEL_SIZE; x++ {
|
||||||
|
r, g, b, _ := resized.At(x, y).RGBA()
|
||||||
|
tensorData = append(tensorData, float32(r)/65535, float32(g)/65535, float32(b)/65535)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Tensor, reshape it, then classify
|
||||||
|
results, err := ModelClassifyTensorBatch(tensorData, 1)
|
||||||
|
if err != nil {
|
||||||
|
return ClassifyResults{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return Results
|
||||||
|
// Drawing[0], Hentai[1], Neutral[2], Porn[3], Sexy[4]
|
||||||
|
return results[0], nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user