package router import ( "net" "net/http" "sync" "time" "golang.org/x/time/rate" ) type Limit struct { rpm int spm int burst int mpr int request map[string]*rate.Limiter size map[string]*rate.Limiter mx sync.Mutex } func NewLimit(requestPerMinute, sizePerMinute, burstSize, memPerRequest int) *Limit { return &Limit{ rpm: requestPerMinute, spm: sizePerMinute, burst: burstSize, mpr: memPerRequest, request: make(map[string]*rate.Limiter), size: make(map[string]*rate.Limiter), mx: sync.Mutex{}, } } func (l *Limit) file(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if l.hasHitRequestLimit(r.RemoteAddr) { http.Error(w, "Too many requests", http.StatusTooManyRequests) return } if err := r.ParseMultipartForm(int64(l.mpr)); err != nil { http.Error(w, "Invalid multipart form", http.StatusBadRequest) return } files := r.MultipartForm.File["file"] if len(files) != 1 { http.Error(w, "Expected exactly one file", http.StatusBadRequest) return } if l.hasHitSizeLimit(r.RemoteAddr, files[0].Size) { http.Error(w, "File too big", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } func (l *Limit) redirect(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if l.hasHitRequestLimit(r.RemoteAddr) { http.Error(w, "Too many requests", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // hasHitRequestLimit returns if the requested remote address has reached the request limit. func (l *Limit) hasHitRequestLimit(remoteAddr string) bool { ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { return true } l.mx.Lock() defer l.mx.Unlock() if l.request[ip] == nil { l.request[ip] = rate.NewLimiter(rate.Limit(l.rpm/60), l.rpm) l.size[ip] = rate.NewLimiter(rate.Limit(l.spm/60), l.burst) } return !l.request[ip].Allow() } // hasHitSizeLimit returns if the requested remote address has reached the size limit. func (l *Limit) hasHitSizeLimit(remoteAddr string, size int64) bool { ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { return true } l.mx.Lock() defer l.mx.Unlock() if l.request[ip] == nil { l.request[ip] = rate.NewLimiter(rate.Limit(l.rpm/60), l.rpm) l.size[ip] = rate.NewLimiter(rate.Limit(l.spm/60), l.burst) } return !l.size[ip].AllowN(time.Now(), int(size)) }