From 61df9cb09b9a45828b986e37ef0844b08f66bea9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCller?= Date: Tue, 19 Sep 2017 23:24:42 +0200 Subject: [PATCH] rate limiter added --- rate_limit.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ render.go | 1 + server.go | 10 ++++++++-- stats.go | 4 +++- 4 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 rate_limit.go diff --git a/rate_limit.go b/rate_limit.go new file mode 100644 index 0000000..3f416b7 --- /dev/null +++ b/rate_limit.go @@ -0,0 +1,54 @@ +package main + +import ( + "sync" + "time" + + "github.com/labstack/echo" +) + +const ( + rateLimit = 20 // times per rateLimitInterval + rateLimitInterval = 1 * time.Hour +) + +var accesses = &sync.Map{} + +type access struct { + count int + timestamp time.Time +} + +func legitAccess(c echo.Context) bool { + ip := c.Request().RemoteAddr + aRaw, found := accesses.Load(ip) + var a *access + if found { + a, _ = aRaw.(*access) + } else { + a = &access{} + } + a.count++ + a.timestamp = time.Now() + accesses.Store(ip, a) + return a.count < rateLimit +} + +func cleanAccessRegistry(logger echo.Logger) { + for { + time.Sleep(rateLimitInterval) + t, e := 0, 0 + accesses.Range(func(ip, aRaw interface{}) bool { + t++ + a, _ := aRaw.(*access) + if a.timestamp.Add(rateLimitInterval).Before(time.Now()) { + accesses.Delete(ip) + e++ + } + return true + }) + if e > 0 { + logger.Infof("cleaned up %d/%d outdated accesses", e, t) + } + } +} diff --git a/render.go b/render.go index 3299898..36ea6e9 100644 --- a/render.go +++ b/render.go @@ -19,6 +19,7 @@ var ( 401: "Unauthorized", 404: "Not found", 412: "Precondition failed", + 429: "Too many requests", 503: "Service unavailable", } diff --git a/server.go b/server.go index 4e70eb1..8dd3464 100644 --- a/server.go +++ b/server.go @@ -49,6 +49,9 @@ func main() { } } + go persistStats(e.Logger, db, stats) + go cleanAccessRegistry(e.Logger) + e.Renderer = &Template{templates: template.Must(template.ParseGlob("assets/templates/*.html"))} e.File("/favicon.ico", "assets/public/favicon.ico") @@ -57,8 +60,6 @@ func main() { e.File("/index.html", "assets/public/index.html") e.File("/", "assets/public/index.html") - go persistStats(e.Logger, db, stats) - e.GET("/TOS.md", func(c echo.Context) error { n, code := md2html(c, "TOS") return c.Render(code, "Page", n) @@ -111,6 +112,11 @@ func main() { e.POST("/note", func(c echo.Context) error { c.Logger().Debug("POST /note requested") + if !legitAccess(c) { + code := http.StatusTooManyRequests + c.Logger().Errorf("rate limit exceeded for %s", c.Request().RemoteAddr) + return c.Render(code, "Note", errPage(code)) + } vals, err := c.FormParams() if err != nil { return err diff --git a/stats.go b/stats.go index 5a8f2e0..7bdcf05 100644 --- a/stats.go +++ b/stats.go @@ -30,6 +30,8 @@ func persistStats(logger echo.Logger, db *sql.DB, stats *sync.Map) { return true }) tx.Commit() - logger.Infof("successfully persisted %d values", c) + if c > 0 { + logger.Infof("successfully persisted %d values", c) + } } }