Browse Source

rate limiter added

master
Christian Müller 8 years ago
parent
commit
61df9cb09b
  1. 54
      rate_limit.go
  2. 1
      render.go
  3. 10
      server.go
  4. 4
      stats.go

54
rate_limit.go

@ -0,0 +1,54 @@
package main
import (
"sync"
"time"
"github.com/labstack/echo"
)
const (
rateLimit = 20 // times per rateLimitInterval
rateLimitInterval = 1 * time.Hour
)
var accesses = &sync.Map{}
type access struct {
count int
timestamp time.Time
}
func legitAccess(c echo.Context) bool {
ip := c.Request().RemoteAddr
aRaw, found := accesses.Load(ip)
var a *access
if found {
a, _ = aRaw.(*access)
} else {
a = &access{}
}
a.count++
a.timestamp = time.Now()
accesses.Store(ip, a)
return a.count < rateLimit
}
func cleanAccessRegistry(logger echo.Logger) {
for {
time.Sleep(rateLimitInterval)
t, e := 0, 0
accesses.Range(func(ip, aRaw interface{}) bool {
t++
a, _ := aRaw.(*access)
if a.timestamp.Add(rateLimitInterval).Before(time.Now()) {
accesses.Delete(ip)
e++
}
return true
})
if e > 0 {
logger.Infof("cleaned up %d/%d outdated accesses", e, t)
}
}
}

1
render.go

@ -19,6 +19,7 @@ var (
401: "Unauthorized", 401: "Unauthorized",
404: "Not found", 404: "Not found",
412: "Precondition failed", 412: "Precondition failed",
429: "Too many requests",
503: "Service unavailable", 503: "Service unavailable",
} }

10
server.go

@ -49,6 +49,9 @@ func main() {
} }
} }
go persistStats(e.Logger, db, stats)
go cleanAccessRegistry(e.Logger)
e.Renderer = &Template{templates: template.Must(template.ParseGlob("assets/templates/*.html"))} e.Renderer = &Template{templates: template.Must(template.ParseGlob("assets/templates/*.html"))}
e.File("/favicon.ico", "assets/public/favicon.ico") e.File("/favicon.ico", "assets/public/favicon.ico")
@ -57,8 +60,6 @@ func main() {
e.File("/index.html", "assets/public/index.html") e.File("/index.html", "assets/public/index.html")
e.File("/", "assets/public/index.html") e.File("/", "assets/public/index.html")
go persistStats(e.Logger, db, stats)
e.GET("/TOS.md", func(c echo.Context) error { e.GET("/TOS.md", func(c echo.Context) error {
n, code := md2html(c, "TOS") n, code := md2html(c, "TOS")
return c.Render(code, "Page", n) return c.Render(code, "Page", n)
@ -111,6 +112,11 @@ func main() {
e.POST("/note", func(c echo.Context) error { e.POST("/note", func(c echo.Context) error {
c.Logger().Debug("POST /note requested") c.Logger().Debug("POST /note requested")
if !legitAccess(c) {
code := http.StatusTooManyRequests
c.Logger().Errorf("rate limit exceeded for %s", c.Request().RemoteAddr)
return c.Render(code, "Note", errPage(code))
}
vals, err := c.FormParams() vals, err := c.FormParams()
if err != nil { if err != nil {
return err return err

4
stats.go

@ -30,6 +30,8 @@ func persistStats(logger echo.Logger, db *sql.DB, stats *sync.Map) {
return true return true
}) })
tx.Commit() tx.Commit()
logger.Infof("successfully persisted %d values", c) if c > 0 {
logger.Infof("successfully persisted %d values", c)
}
} }
} }

Loading…
Cancel
Save