cabinet/internal/router/rate.go

108 lines
2.4 KiB
Go

package router
import (
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
// Limit holds the configuration and per-client state used by the rate-limiting
// middleware in this package. Clients are identified by the host part of
// their remote address.
type Limit struct {
	// rpm is the requests-per-minute budget and spm the size-per-minute
	// budget used when a client's limiters are created. burst is the burst
	// size of the size limiter. mpr is the maxMemory value handed to
	// ParseMultipartForm for each upload request.
	rpm, spm, burst, mpr int

	// request and size map a client IP to its request limiter and its
	// size limiter respectively. Entries are created lazily.
	request, size map[string]*rate.Limiter

	// mx guards the request and size maps.
	mx sync.Mutex
}
// NewLimit returns a Limit configured with the given budgets and with empty
// per-client limiter tables.
//
// requestPerMinute and sizePerMinute are the per-client request and size
// budgets, burstSize is the burst of the per-client size limiter, and
// memPerRequest is the in-memory limit used when parsing multipart uploads.
func NewLimit(requestPerMinute, sizePerMinute, burstSize, memPerRequest int) *Limit {
	l := new(Limit) // the zero-value mutex is ready to use
	l.rpm = requestPerMinute
	l.spm = sizePerMinute
	l.burst = burstSize
	l.mpr = memPerRequest
	l.request = map[string]*rate.Limiter{}
	l.size = map[string]*rate.Limiter{}
	return l
}
// file is middleware for upload endpoints. Before handing the request to next
// it checks, in order:
//
//  1. the client's request limit (429 "Too many requests"),
//  2. that the body parses as a multipart form (400),
//  3. that the form carries exactly one part named "file" (400),
//  4. the client's size limit for that file (429 "File too big").
//
// The parsed form stays attached to the request, so next can read it from
// r.MultipartForm.
func (l *Limit) file(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		client := r.RemoteAddr

		if l.hasHitRequestLimit(client) {
			http.Error(w, "Too many requests", http.StatusTooManyRequests)
			return
		}

		parseErr := r.ParseMultipartForm(int64(l.mpr))
		if parseErr != nil {
			http.Error(w, "Invalid multipart form", http.StatusBadRequest)
			return
		}

		uploads := r.MultipartForm.File["file"]
		if len(uploads) != 1 {
			http.Error(w, "Expected exactly one file", http.StatusBadRequest)
			return
		}

		// The upload's size is charged against the client's size budget.
		if l.hasHitSizeLimit(client, uploads[0].Size) {
			http.Error(w, "File too big", http.StatusTooManyRequests)
			return
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(handle)
}
// redirect is middleware that applies only the per-client request limit:
// requests within budget are passed to next, all others are answered with
// 429 "Too many requests".
func (l *Limit) redirect(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if !l.hasHitRequestLimit(r.RemoteAddr) {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Too many requests", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(handle)
}
// hasHitRequestLimit reports whether the client at remoteAddr has exhausted
// its request budget. Each call consumes one token from that client's request
// limiter when one is available.
//
// remoteAddr must be in "host:port" form, as in http.Request.RemoteAddr. An
// address that cannot be split is reported as over the limit, so malformed
// input fails closed.
func (l *Limit) hasHitRequestLimit(remoteAddr string) bool {
	ip, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return true
	}

	l.mx.Lock()
	defer l.mx.Unlock()

	// Limiters are created lazily the first time an IP is seen. rate.Limit is
	// expressed in events per second, so the per-minute budgets are divided
	// by 60 as floats: integer division would turn any budget below 60 into
	// a rate of zero (a limiter that never refills) and round others down.
	reqLimiter := l.request[ip]
	if reqLimiter == nil {
		reqLimiter = rate.NewLimiter(rate.Limit(float64(l.rpm)/60), l.rpm)
		l.request[ip] = reqLimiter
	}
	// The size limiter is created alongside so that a later size check for
	// this IP always finds one.
	if l.size[ip] == nil {
		l.size[ip] = rate.NewLimiter(rate.Limit(float64(l.spm)/60), l.burst)
	}

	return !reqLimiter.Allow()
}
// hasHitSizeLimit reports whether charging size against the size budget of
// the client at remoteAddr would exceed it. When the charge fits, size tokens
// are consumed from that client's size limiter.
//
// remoteAddr must be in "host:port" form, as in http.Request.RemoteAddr. An
// address that cannot be split is reported as over the limit, so malformed
// input fails closed.
func (l *Limit) hasHitSizeLimit(remoteAddr string, size int64) bool {
	ip, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return true
	}

	l.mx.Lock()
	defer l.mx.Unlock()

	// Limiters are created lazily the first time an IP is seen. rate.Limit is
	// expressed in events per second, so the per-minute budgets are divided
	// by 60 as floats: integer division would turn any budget below 60 into
	// a rate of zero (a limiter that never refills) and round others down.
	if l.request[ip] == nil {
		l.request[ip] = rate.NewLimiter(rate.Limit(float64(l.rpm)/60), l.rpm)
	}
	// Look the size limiter up directly instead of inferring its presence
	// from the request map, so a missing entry can never be dereferenced.
	sizeLimiter := l.size[ip]
	if sizeLimiter == nil {
		sizeLimiter = rate.NewLimiter(rate.Limit(float64(l.spm)/60), l.burst)
		l.size[ip] = sizeLimiter
	}

	// A charge larger than the burst can never be granted by the limiter.
	// Rejecting it (and negative sizes) here also keeps the int conversion
	// below from wrapping on platforms where int is 32 bits wide.
	if size < 0 || size > int64(l.burst) {
		return true
	}

	return !sizeLimiter.AllowN(time.Now(), int(size))
}