108 lines
2.4 KiB
Go
108 lines
2.4 KiB
Go
package router
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
type Limit struct {
|
|
rpm int
|
|
spm int
|
|
burst int
|
|
mpr int
|
|
|
|
request map[string]*rate.Limiter
|
|
size map[string]*rate.Limiter
|
|
mx sync.Mutex
|
|
}
|
|
|
|
func NewLimit(requestPerMinute, sizePerMinute, burstSize, memPerRequest int) *Limit {
|
|
return &Limit{
|
|
rpm: requestPerMinute,
|
|
spm: sizePerMinute,
|
|
burst: burstSize,
|
|
mpr: memPerRequest,
|
|
request: make(map[string]*rate.Limiter),
|
|
size: make(map[string]*rate.Limiter),
|
|
mx: sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
func (l *Limit) file(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if l.hasHitRequestLimit(r.RemoteAddr) {
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
if err := r.ParseMultipartForm(int64(l.mpr)); err != nil {
|
|
http.Error(w, "Invalid multipart form", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
files := r.MultipartForm.File["file"]
|
|
if len(files) != 1 {
|
|
http.Error(w, "Expected exactly one file", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if l.hasHitSizeLimit(r.RemoteAddr, files[0].Size) {
|
|
http.Error(w, "File too big", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (l *Limit) redirect(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if l.hasHitRequestLimit(r.RemoteAddr) {
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// hasHitRequestLimit returns if the requested remote address has reached the request limit.
|
|
func (l *Limit) hasHitRequestLimit(remoteAddr string) bool {
|
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
|
|
l.mx.Lock()
|
|
defer l.mx.Unlock()
|
|
|
|
if l.request[ip] == nil {
|
|
l.request[ip] = rate.NewLimiter(rate.Limit(l.rpm/60), l.rpm)
|
|
l.size[ip] = rate.NewLimiter(rate.Limit(l.spm/60), l.burst)
|
|
}
|
|
|
|
return !l.request[ip].Allow()
|
|
}
|
|
|
|
// hasHitSizeLimit returns if the requested remote address has reached the size limit.
|
|
func (l *Limit) hasHitSizeLimit(remoteAddr string, size int64) bool {
|
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
|
|
l.mx.Lock()
|
|
defer l.mx.Unlock()
|
|
|
|
if l.request[ip] == nil {
|
|
l.request[ip] = rate.NewLimiter(rate.Limit(l.rpm/60), l.rpm)
|
|
l.size[ip] = rate.NewLimiter(rate.Limit(l.spm/60), l.burst)
|
|
}
|
|
|
|
return !l.size[ip].AllowN(time.Now(), int(size))
|
|
}
|