Initial Release
This commit is contained in:
@@ -0,0 +1,440 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
onnx "github.com/yalue/onnxruntime_go"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
"golang.org/x/image/webp"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type ImageType string
|
||||
|
||||
const (
|
||||
IMAGE_OTHER ImageType = "UNKNOWN"
|
||||
IMAGE_WEBP ImageType = "WEBP"
|
||||
IMAGE_JPEG ImageType = "JPG"
|
||||
IMAGE_PNG ImageType = "PNG"
|
||||
IMAGE_GIF ImageType = "GIF"
|
||||
MODEL_SIZE = 224
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed nsfw.onnx
|
||||
ONNX_MODEL []byte
|
||||
ONNX_SESSION *onnx.DynamicAdvancedSession
|
||||
ONNX_RUNTIME_PATH = os.Getenv("ONNX_RUNTIME_PATH")
|
||||
ONNX_RUNTIME_CUDA = os.Getenv("ONNX_RUNTIME_CUDA") != ""
|
||||
HTTP_ADDRESS = os.Getenv("HTTP_ADDRESS")
|
||||
HTTP_SEMAPHORE = semaphore.NewWeighted(16)
|
||||
HTTP_MAX_BODY_BYTES int64 = 16 * 1024 * 1024 // 16mb
|
||||
MODEL_THRESHOLD float32 = 0.7
|
||||
)
|
||||
|
||||
func init() {
|
||||
if HTTP_ADDRESS == "" {
|
||||
HTTP_ADDRESS = "127.0.0.1:9000"
|
||||
}
|
||||
|
||||
if plaintext := os.Getenv("HTTP_CONCURRENCY"); plaintext != "" {
|
||||
val, err := strconv.ParseInt(plaintext, 10, 64)
|
||||
if err == nil {
|
||||
HTTP_SEMAPHORE = semaphore.NewWeighted(val)
|
||||
}
|
||||
}
|
||||
|
||||
if plaintext := os.Getenv("HTTP_MAX_BODY_BYTES"); plaintext != "" {
|
||||
val, err := strconv.ParseInt(plaintext, 10, 64)
|
||||
if err == nil {
|
||||
HTTP_MAX_BODY_BYTES = val
|
||||
}
|
||||
}
|
||||
|
||||
if plaintext := os.Getenv("MODEL_THRESHOLD"); plaintext != "" {
|
||||
val, err := strconv.ParseFloat(plaintext, 32)
|
||||
if err == nil {
|
||||
MODEL_THRESHOLD = float32(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClassifyResults struct {
|
||||
Rating float32
|
||||
Drawing float32
|
||||
Hentai float32
|
||||
Neutral float32
|
||||
Porn float32
|
||||
Sexy float32
|
||||
}
|
||||
|
||||
func main() {
|
||||
time.Local = time.UTC
|
||||
|
||||
// Startup Services
|
||||
var stopCtx, stop = context.WithCancel(context.Background())
|
||||
var stopWg sync.WaitGroup
|
||||
|
||||
SetupModel(stopCtx, &stopWg)
|
||||
go SetupHTTP(stopCtx, &stopWg)
|
||||
|
||||
// Await Shutdown Signal
|
||||
cancel := make(chan os.Signal, 1)
|
||||
signal.Notify(cancel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-cancel
|
||||
stop()
|
||||
|
||||
// Begin Shutdown Process
|
||||
timeout, finish := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer finish()
|
||||
go func() {
|
||||
<-timeout.Done()
|
||||
if timeout.Err() == context.DeadlineExceeded {
|
||||
log.Fatalln("[MAIN] Shutdown Deadline Exceeded")
|
||||
}
|
||||
}()
|
||||
stopWg.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func SetupHTTP(stop context.Context, await *sync.WaitGroup) {
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
var elapsedProbe, elapsedDecode, elapsedClassify int64
|
||||
t := time.Now()
|
||||
|
||||
switch r.Method {
|
||||
|
||||
case http.MethodPost:
|
||||
// Request Validation
|
||||
if r.ContentLength > HTTP_MAX_BODY_BYTES {
|
||||
http.Error(w, "Content Too Large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
r.Body = http.MaxBytesReader(w, r.Body, HTTP_MAX_BODY_BYTES)
|
||||
|
||||
var d []byte
|
||||
if raw, err := io.ReadAll(r.Body); err != nil {
|
||||
http.Error(w, "Body Error", http.StatusLengthRequired)
|
||||
return
|
||||
} else {
|
||||
d = raw
|
||||
}
|
||||
|
||||
// Image Validation
|
||||
var imageType ImageType
|
||||
var imageInfo image.Config
|
||||
var imageData image.Image
|
||||
var imageErr error
|
||||
|
||||
switch { // Sniffing
|
||||
case len(d) > 3 && // JPEG
|
||||
d[0] == 0xFF && d[1] == 0xD8 && d[2] == 0xFF:
|
||||
imageType = IMAGE_JPEG
|
||||
|
||||
case len(d) > 8 && // PNG
|
||||
d[0] == 0x89 && d[1] == 0x50 && d[2] == 0x4E && d[3] == 0x47 &&
|
||||
d[4] == 0x0D && d[5] == 0x0A && d[6] == 0x1A && d[7] == 0x0A:
|
||||
imageType = IMAGE_PNG
|
||||
|
||||
case len(d) > 4 && // GIF
|
||||
d[0] == 0x47 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x38:
|
||||
imageType = IMAGE_GIF
|
||||
|
||||
case len(d) > 12 && // WEBP
|
||||
d[0] == 0x52 && d[1] == 0x49 && d[2] == 0x46 && d[3] == 0x46 &&
|
||||
d[8] == 0x57 && d[9] == 0x45 && d[10] == 0x42 && d[11] == 0x50:
|
||||
imageType = IMAGE_WEBP
|
||||
|
||||
default:
|
||||
imageType = IMAGE_OTHER
|
||||
}
|
||||
|
||||
switch imageType { // Decode Header
|
||||
case IMAGE_WEBP:
|
||||
imageInfo, imageErr = webp.DecodeConfig(bytes.NewReader(d))
|
||||
case IMAGE_JPEG:
|
||||
imageInfo, imageErr = jpeg.DecodeConfig(bytes.NewReader(d))
|
||||
case IMAGE_PNG:
|
||||
imageInfo, imageErr = png.DecodeConfig(bytes.NewReader(d))
|
||||
case IMAGE_GIF:
|
||||
imageInfo, imageErr = gif.DecodeConfig(bytes.NewReader(d))
|
||||
default:
|
||||
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
elapsedProbe = time.Since(t).Microseconds()
|
||||
t = time.Now()
|
||||
|
||||
if imageErr != nil {
|
||||
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if imageInfo.Width > 8192 {
|
||||
http.Error(w, "Image cannot be wider than 8192 pixels", http.StatusUnprocessableEntity)
|
||||
return
|
||||
}
|
||||
if imageInfo.Height > 4096 {
|
||||
http.Error(w, "Image cannot be taller than 4096 pixels", http.StatusUnprocessableEntity)
|
||||
return
|
||||
}
|
||||
if imageInfo.Height < 32 || imageInfo.Width < 32 {
|
||||
http.Error(w, "Image cannot be smaller than 32 pixels", http.StatusUnprocessableEntity)
|
||||
return
|
||||
}
|
||||
|
||||
switch imageType { // Decode Pixels
|
||||
case IMAGE_WEBP:
|
||||
imageData, imageErr = webp.Decode(bytes.NewReader(d))
|
||||
case IMAGE_JPEG:
|
||||
imageData, imageErr = jpeg.Decode(bytes.NewReader(d))
|
||||
case IMAGE_PNG:
|
||||
imageData, imageErr = png.Decode(bytes.NewReader(d))
|
||||
case IMAGE_GIF:
|
||||
imageData, imageErr = gif.Decode(bytes.NewReader(d))
|
||||
default:
|
||||
http.Error(w, "Unsupported Image Format", http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
if imageErr != nil {
|
||||
http.Error(w, "Invalid Image Data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
elapsedDecode = time.Since(t).Microseconds()
|
||||
t = time.Now()
|
||||
|
||||
// Image Classification
|
||||
results, err := ModelClassifyImage(imageData)
|
||||
if err != nil {
|
||||
http.Error(w, "Classify Error: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
elapsedClassify = time.Since(t).Microseconds()
|
||||
t = time.Now()
|
||||
|
||||
// Output Results
|
||||
output, err := json.Marshal(map[string]any{
|
||||
"allowed": results.Rating < MODEL_THRESHOLD,
|
||||
"timings": map[string]any{
|
||||
"probe": elapsedProbe,
|
||||
"decode": elapsedDecode,
|
||||
"classify": elapsedClassify,
|
||||
"total": elapsedProbe + elapsedDecode + elapsedClassify,
|
||||
},
|
||||
"image": map[string]any{
|
||||
"hash_sha256": fmt.Sprintf("%x", sha256.Sum256(d)),
|
||||
"hash_md5": fmt.Sprintf("%x", md5.Sum(d)),
|
||||
"height": imageInfo.Height,
|
||||
"width": imageInfo.Width,
|
||||
},
|
||||
"logits": map[string]any{
|
||||
"rating": math.Round(float64(results.Rating)*10000) / 100,
|
||||
"drawing": math.Round(float64(results.Drawing)*10000) / 100,
|
||||
"hentai": math.Round(float64(results.Hentai)*10000) / 100,
|
||||
"neutral": math.Round(float64(results.Neutral)*10000) / 100,
|
||||
"porn": math.Round(float64(results.Porn)*10000) / 100,
|
||||
"sexy": math.Round(float64(results.Sexy)*10000) / 100,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "Encoding Error: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Write(output)
|
||||
fmt.Fprintf(os.Stdout, "%s\n", output)
|
||||
|
||||
// Other Methods
|
||||
case http.MethodHead, http.MethodGet:
|
||||
http.Error(w, "nsfw-service ; ><> .o( blub blub )", http.StatusOK)
|
||||
|
||||
default:
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
svr := http.Server{
|
||||
Handler: mux,
|
||||
Addr: HTTP_ADDRESS,
|
||||
MaxHeaderBytes: 4096,
|
||||
IdleTimeout: 10 * time.Second,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Shutdown Logic
|
||||
await.Add(1)
|
||||
go func() {
|
||||
defer await.Done()
|
||||
<-stop.Done()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := svr.Shutdown(shutdownCtx); err != nil {
|
||||
log.Println("[HTTP] Shutdown Failed:", err)
|
||||
}
|
||||
|
||||
log.Println("[HTTP] Server Closed")
|
||||
}()
|
||||
|
||||
// Server Startup
|
||||
log.Println("[HTTP] Listening @", HTTP_ADDRESS)
|
||||
if err := svr.ListenAndServe(); err != http.ErrServerClosed {
|
||||
log.Fatalln("[HTTP] Startup Failed:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func SetupModel(stop context.Context, await *sync.WaitGroup) {
|
||||
t := time.Now()
|
||||
|
||||
// Initialize Environment
|
||||
onnx.SetSharedLibraryPath(ONNX_RUNTIME_PATH)
|
||||
if err := onnx.InitializeEnvironment(); err != nil {
|
||||
log.Fatalf("[ONNX] Failed to initialize ONNX Runtime: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize Settings
|
||||
options, err := onnx.NewSessionOptions()
|
||||
if err != nil {
|
||||
log.Fatalf("[ONNX] Failed to create session options: %s\n", err)
|
||||
return
|
||||
}
|
||||
defer options.Destroy()
|
||||
|
||||
if ONNX_RUNTIME_CUDA {
|
||||
cudaOptions, err := onnx.NewCUDAProviderOptions()
|
||||
if err != nil {
|
||||
log.Printf("[ONNX] CUDA unavailable, falling back to CPU: %s\n", err)
|
||||
} else {
|
||||
defer cudaOptions.Destroy()
|
||||
cudaOptions.Update(map[string]string{
|
||||
"cudnn_conv_algo_search": "DEFAULT", // use the only working frontend
|
||||
})
|
||||
options.AppendExecutionProviderCUDA(cudaOptions)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize Model
|
||||
session, err := onnx.NewDynamicAdvancedSessionWithONNXData(
|
||||
ONNX_MODEL,
|
||||
[]string{"input"},
|
||||
[]string{"prediction"},
|
||||
options,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[ONNX] Failed to load model: %s\n", err)
|
||||
return
|
||||
}
|
||||
ONNX_SESSION = session
|
||||
|
||||
// Test Model with Dummy Data
|
||||
dummy := make([]float32, MODEL_SIZE*MODEL_SIZE*3)
|
||||
if _, err := ModelClassifyTensorBatch(dummy, 1); err != nil {
|
||||
log.Printf("[ONNX] Failed to initialize model: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
await.Add(1)
|
||||
go func() {
|
||||
defer await.Done()
|
||||
<-stop.Done()
|
||||
ONNX_SESSION.Destroy()
|
||||
ONNX_SESSION = nil
|
||||
onnx.DestroyEnvironment()
|
||||
log.Println("[ONNX] Closed")
|
||||
}()
|
||||
|
||||
log.Printf("[ONNX] Model ready in %s\n", time.Since(t))
|
||||
}
|
||||
|
||||
// Cast Predictions on a Tensor using the NSFW Model
|
||||
func ModelClassifyTensorBatch(data []float32, count int) ([]ClassifyResults, error) {
|
||||
|
||||
inputTensor, err := onnx.NewTensor(
|
||||
onnx.NewShape(int64(count), MODEL_SIZE, MODEL_SIZE, 3),
|
||||
data,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer inputTensor.Destroy()
|
||||
|
||||
outputs := []onnx.ArbitraryTensor{nil}
|
||||
if err := ONNX_SESSION.Run([]onnx.ArbitraryTensor{inputTensor}, outputs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer outputs[0].Destroy()
|
||||
|
||||
raw := outputs[0].(*onnx.Tensor[float32]).GetData()
|
||||
results := make([]ClassifyResults, count)
|
||||
|
||||
for i := range results {
|
||||
base := i * 5
|
||||
results[i] = ClassifyResults{
|
||||
Rating: (raw[base+1] + raw[base+3] + (raw[base+4] * 0.9)),
|
||||
Drawing: raw[base+0],
|
||||
Hentai: raw[base+1],
|
||||
Neutral: raw[base+2],
|
||||
Porn: raw[base+3],
|
||||
Sexy: raw[base+4],
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Classify an Image returning true if it's considered safe
|
||||
func ModelClassifyImage(someImage image.Image) (ClassifyResults, error) {
|
||||
|
||||
// Resize Image to Usable Size
|
||||
resized := image.NewRGBA(image.Rect(0, 0, MODEL_SIZE, MODEL_SIZE))
|
||||
draw.CatmullRom.Scale(resized, resized.Rect, someImage, someImage.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert Pixel Data into Normalized Floats
|
||||
var tensorCap = MODEL_SIZE * MODEL_SIZE * 3
|
||||
var tensorData = make([]float32, 0, tensorCap)
|
||||
for y := 0; y < MODEL_SIZE; y++ {
|
||||
for x := 0; x < MODEL_SIZE; x++ {
|
||||
r, g, b, _ := resized.At(x, y).RGBA()
|
||||
tensorData = append(tensorData, float32(r)/65535, float32(g)/65535, float32(b)/65535)
|
||||
}
|
||||
}
|
||||
|
||||
// Create Tensor, reshape it, then classify
|
||||
results, err := ModelClassifyTensorBatch(tensorData, 1)
|
||||
if err != nil {
|
||||
return ClassifyResults{}, err
|
||||
}
|
||||
|
||||
// Return Results
|
||||
// Drawing[0], Hentai[1], Neutral[2], Porn[3], Sexy[4]
|
||||
return results[0], nil
|
||||
}
|
||||
Reference in New Issue
Block a user