From f97a543d416d5e98b69a8a63dc341ef8ff019807 Mon Sep 17 00:00:00 2001 From: Eric Mc Sween Date: Tue, 17 Jan 2023 07:32:51 -0500 Subject: [PATCH] Merge pull request #11255 from overleaf/em-rate-limiter Introduce rate-limiter-flexible GitOrigin-RevId: c787397e276fb81015c7d045d191f2ad81ef542d --- package-lock.json | 12 ++ .../Security/RateLimiterMiddleware.js | 39 ++++ .../web/app/src/infrastructure/RateLimiter.js | 55 +++++- services/web/app/src/router.js | 13 +- services/web/package.json | 1 + .../Security/RateLimiterMiddlewareTests.js | 178 +++++++++++++----- 6 files changed, 240 insertions(+), 58 deletions(-) diff --git a/package-lock.json b/package-lock.json index cc3e8e8e0e..b4165990f2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -26737,6 +26737,11 @@ "node": ">= 0.6" } }, + "node_modules/rate-limiter-flexible": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/rate-limiter-flexible/-/rate-limiter-flexible-2.4.1.tgz", + "integrity": "sha512-dgH4T44TzKVO9CLArNto62hJOwlWJMLUjVVr/ii0uUzZXEXthDNr7/yefW5z/1vvHAfycc1tnuiYyNJ8CTRB3g==" + }, "node_modules/raw-body": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", @@ -36324,6 +36329,7 @@ "pug": "^3.0.1", "pug-runtime": "^3.0.1", "qrcode": "^1.4.4", + "rate-limiter-flexible": "^2.4.1", "react": "^17.0.2", "react-bootstrap": "^0.33.1", "react-chartjs-2": "^5.0.1", @@ -47171,6 +47177,7 @@ "pug": "^3.0.1", "pug-runtime": "^3.0.1", "qrcode": "^1.4.4", + "rate-limiter-flexible": "^2.4.1", "react": "^17.0.2", "react-bootstrap": "^0.33.1", "react-chartjs-2": "^5.0.1", @@ -64660,6 +64667,11 @@ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==" }, + "rate-limiter-flexible": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/rate-limiter-flexible/-/rate-limiter-flexible-2.4.1.tgz", + "integrity": "sha512-dgH4T44TzKVO9CLArNto62hJOwlWJMLUjVVr/ii0uUzZXEXthDNr7/yefW5z/1vvHAfycc1tnuiYyNJ8CTRB3g==" + }, "raw-body": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", diff --git a/services/web/app/src/Features/Security/RateLimiterMiddleware.js b/services/web/app/src/Features/Security/RateLimiterMiddleware.js index c32c7e9d81..7095b96857 100644 --- a/services/web/app/src/Features/Security/RateLimiterMiddleware.js +++ b/services/web/app/src/Features/Security/RateLimiterMiddleware.js @@ -59,6 +59,44 @@ function rateLimit(opts) { } } +function rateLimitV2(rateLimiter, opts = {}) { + const getUserId = + opts.getUserId || (req => SessionManager.getLoggedInUserId(req.session)) + return function (req, res, next) { + const userId = getUserId(req) || req.ip + if ( + settings.smokeTest && + settings.smokeTest.userId && + settings.smokeTest.userId.toString() === userId.toString() + ) { + // ignore smoke test user + return next() + } + + let key + if (opts.ipOnly) { + key = req.ip + } else { + const params = (opts.params || []).map(p => req.params[p]) + params.push(userId) + key = params.join(':') + } + + rateLimiter + .consume(key) + .then(() => next()) + .catch(err => { + if (err instanceof Error) { + next(err) + } else { + res.status(429) // Too many requests + res.write('Rate limit reached, please try again later') + res.end() + } + }) + } +} + function loginRateLimit(req, res, next) { const { email } = req.body if (!email) { @@ -81,6 +119,7 @@ function loginRateLimit(req, res, next) { const RateLimiterMiddleware = { rateLimit, + rateLimitV2, loginRateLimit, } diff --git a/services/web/app/src/infrastructure/RateLimiter.js b/services/web/app/src/infrastructure/RateLimiter.js index 6691edb820..6bac414898 100644 --- a/services/web/app/src/infrastructure/RateLimiter.js +++ b/services/web/app/src/infrastructure/RateLimiter.js @@ -1,8 +1,10 @@ const settings = require('@overleaf/settings') const Metrics = require('@overleaf/metrics') +const logger = require('@overleaf/logger') const RedisWrapper = require('./RedisWrapper') const rclient = RedisWrapper.client('ratelimiter') -const { RedisRateLimiter } = require('rolling-rate-limiter') +const RollingRateLimiter = require('rolling-rate-limiter') +const RateLimiterFlexible = require('rate-limiter-flexible') const { callbackify } = require('util') async function addCount(opts) { @@ -11,7 +13,7 @@ async function addCount(opts) { } const namespace = `RateLimit:${opts.endpointName}:` const k = `{${opts.subjectName}}` - const limiter = new RedisRateLimiter({ + const limiter = new RollingRateLimiter.RedisRateLimiter({ client: rclient, namespace, interval: opts.timeInterval * 1000, @@ -32,9 +34,58 @@ async function clearRateLimit(endpointName, subject) { await rclient.del(keyName) } +/** + * Wrapper over the RateLimiterRedis class + */ +class RateLimiter { + /** + * Create a rate limiter. + * + * @param name {string} The name that identifies this rate limiter. Different + * rate limiters must have different names. + * @param opts {object} Options to pass to RateLimiterRedis + * + * Some useful options: + * + * points - number of points that can be consumed over the given duration + * (default: 4) + * duration - duration of the fixed window in seconds (default: 1) + * blockDuration - additional seconds to block after all points are consumed + * (default: 0) + */ + constructor(name, opts = {}) { + this.name = name + this._rateLimiter = new RateLimiterFlexible.RateLimiterRedis({ + ...opts, + keyPrefix: `rate-limit:${name}`, + storeClient: rclient, + }) + } + + async consume(key, points = 1, options = {}) { + try { + const res = await this._rateLimiter.consume(key, points, options) + return res + } catch (err) { + if (err instanceof Error) { + throw err + } else { + // Only log the first time we exceed the rate limit for a given key and + // duration + if (err.consumedPoints === this._rateLimiter.points + points) { + logger.warn({ path: this.name, key }, 'rate limit exceeded') + } + Metrics.inc('rate-limit-hit', 1, { path: this.name }) + throw err + } + } + } +} + module.exports = { addCount: callbackify(addCount), clearRateLimit: callbackify(clearRateLimit), + RateLimiter, promises: { addCount, clearRateLimit, diff --git a/services/web/app/src/router.js b/services/web/app/src/router.js index f95bb526f9..6cd3e0b4e8 100644 --- a/services/web/app/src/router.js +++ b/services/web/app/src/router.js @@ -35,6 +35,7 @@ const PasswordResetRouter = require('./Features/PasswordReset/PasswordResetRoute const StaticPagesRouter = require('./Features/StaticPages/StaticPagesRouter') const ChatController = require('./Features/Chat/ChatController') const Modules = require('./infrastructure/Modules') +const { RateLimiter } = require('./infrastructure/RateLimiter') const RateLimiterMiddleware = require('./Features/Security/RateLimiterMiddleware') const InactiveProjectController = require('./Features/InactiveData/InactiveProjectController') const ContactRouter = require('./Features/Contacts/ContactRouter') @@ -68,6 +69,13 @@ const PublicAccessLevels = require('./Features/Authorization/PublicAccessLevels' module.exports = { initialize } +const rateLimiters = { + zipDownload: new RateLimiter('zip-download', { + points: 10, + duration: 60, + }), +} + function initialize(webRouter, privateApiRouter, publicApiRouter) { webRouter.use(unsupportedBrowserMiddleware) @@ -721,11 +729,8 @@ function initialize(webRouter, privateApiRouter, publicApiRouter) { webRouter.get( '/Project/:Project_id/download/zip', - RateLimiterMiddleware.rateLimit({ - endpointName: 'zip-download', + RateLimiterMiddleware.rateLimitV2(rateLimiters.zipDownload, { params: ['Project_id'], - maxRequests: 10, - timeInterval: 60, }), AuthorizationMiddleware.ensureUserCanReadProject, ProjectDownloadsController.downloadProject diff --git a/services/web/package.json b/services/web/package.json index 9326af6917..f34b07a8bc 100644 --- a/services/web/package.json +++ b/services/web/package.json @@ -212,6 +212,7 @@ "pug": "^3.0.1", "pug-runtime": "^3.0.1", "qrcode": "^1.4.4", + "rate-limiter-flexible": "^2.4.1", "react": "^17.0.2", "react-bootstrap": "^0.33.1", "react-chartjs-2": "^5.0.1", diff --git a/services/web/test/unit/src/Security/RateLimiterMiddlewareTests.js b/services/web/test/unit/src/Security/RateLimiterMiddlewareTests.js index 7816cbb305..75f358353f 100644 --- a/services/web/test/unit/src/Security/RateLimiterMiddlewareTests.js +++ b/services/web/test/unit/src/Security/RateLimiterMiddlewareTests.js @@ -1,16 +1,3 @@ -/* eslint-disable - max-len, - no-return-assign, -*/ -// TODO: This file was created by bulk-decaffeinate. -// Fix any style issues and re-enable lint. -/* - * decaffeinate suggestions: - * DS102: Remove unnecessary code created because of implicit returns - * DS103: Rewrite code to no longer use __guard__ - * DS207: Consider shorter variations of null checks - * Full docs: https://github.com/decaffeinate/decaffeinate/blob/master/docs/suggestions.md - */ const SandboxedModule = require('sandboxed-module') const sinon = require('sinon') const modulePath = require('path').join( @@ -21,20 +8,15 @@ const modulePath = require('path').join( describe('RateLimiterMiddleware', function () { beforeEach(function () { this.SessionManager = { - getLoggedInUserId: () => { - return __guard__( - __guard__( - this.req != null ? this.req.session : undefined, - x1 => x1.user - ), - x => x._id - ) - }, + getLoggedInUserId: () => this.req.session?.user?._id, + } + this.RateLimiter = { + addCount: sinon.stub().yields(null, true), } this.RateLimiterMiddleware = SandboxedModule.require(modulePath, { requires: { '@overleaf/settings': (this.settings = {}), - '../../infrastructure/RateLimiter': (this.RateLimiter = {}), + '../../infrastructure/RateLimiter': this.RateLimiter, './LoginRateLimiter': {}, '../Authentication/SessionManager': this.SessionManager, }, @@ -45,32 +27,31 @@ describe('RateLimiterMiddleware', function () { write: sinon.stub(), end: sinon.stub(), } - return (this.next = sinon.stub()) + this.next = sinon.stub() }) describe('rateLimit', function () { beforeEach(function () { - this.rateLimiter = this.RateLimiterMiddleware.rateLimit({ + this.middleware = this.RateLimiterMiddleware.rateLimit({ endpointName: 'test-endpoint', params: ['project_id', 'doc_id'], timeInterval: 42, maxRequests: 12, }) - return (this.req.params = { + this.req.params = { project_id: (this.project_id = 'project-id'), doc_id: (this.doc_id = 'doc-id'), - }) + } }) describe('when there is no session', function () { beforeEach(function () { - this.RateLimiter.addCount = sinon.stub().callsArgWith(1, null, true) this.req.ip = this.ip = '1.2.3.4' - return this.rateLimiter(this.req, this.res, this.next) + this.middleware(this.req, this.res, this.next) }) it('should call the rate limiter backend with the ip address', function () { - return this.RateLimiter.addCount + this.RateLimiter.addCount .calledWith({ endpointName: 'test-endpoint', timeInterval: 42, @@ -91,8 +72,7 @@ describe('RateLimiterMiddleware', function () { }, } this.settings.smokeTest = { userId: this.user_id } - this.RateLimiter.addCount = sinon.stub().callsArgWith(1, null, true) - return this.rateLimiter(this.req, this.res, this.next) + this.middleware(this.req, this.res, this.next) }) it('should not call the rate limiter backend with the user_id', function () { @@ -108,7 +88,7 @@ describe('RateLimiterMiddleware', function () { }) it('should pass on to next()', function () { - return this.next.called.should.equal(true) + this.next.called.should.equal(true) }) }) @@ -119,12 +99,11 @@ describe('RateLimiterMiddleware', function () { _id: (this.user_id = 'user-id'), }, } - this.RateLimiter.addCount = sinon.stub().callsArgWith(1, null, true) - return this.rateLimiter(this.req, this.res, this.next) + this.middleware(this.req, this.res, this.next) }) it('should call the rate limiter backend with the user_id', function () { - return this.RateLimiter.addCount + this.RateLimiter.addCount .calledWith({ endpointName: 'test-endpoint', timeInterval: 42, @@ -135,19 +114,18 @@ describe('RateLimiterMiddleware', function () { }) it('should pass on to next()', function () { - return this.next.called.should.equal(true) + this.next.called.should.equal(true) }) }) describe('when under the rate limit with anonymous user', function () { beforeEach(function () { this.req.ip = this.ip = '1.2.3.4' - this.RateLimiter.addCount = sinon.stub().callsArgWith(1, null, true) - return this.rateLimiter(this.req, this.res, this.next) + this.middleware(this.req, this.res, this.next) }) it('should call the rate limiter backend with the ip address', function () { - return this.RateLimiter.addCount + this.RateLimiter.addCount .calledWith({ endpointName: 'test-endpoint', timeInterval: 42, @@ -158,7 +136,7 @@ describe('RateLimiterMiddleware', function () { }) it('should pass on to next()', function () { - return this.next.called.should.equal(true) + this.next.called.should.equal(true) }) }) @@ -169,21 +147,21 @@ describe('RateLimiterMiddleware', function () { _id: (this.user_id = 'user-id'), }, } - this.RateLimiter.addCount = sinon.stub().callsArgWith(1, null, false) - return this.rateLimiter(this.req, this.res, this.next) + this.RateLimiter.addCount.yields(null, false) + this.middleware(this.req, this.res, this.next) }) it('should return a 429', function () { this.res.status.calledWith(429).should.equal(true) - return this.res.end.called.should.equal(true) + this.res.end.called.should.equal(true) }) it('should not continue', function () { - return this.next.called.should.equal(false) + this.next.called.should.equal(false) }) it('should log a warning', function () { - return this.logger.warn + this.logger.warn .calledWith( { endpointName: 'test-endpoint', @@ -197,10 +175,106 @@ describe('RateLimiterMiddleware', function () { }) }) }) -}) -function __guard__(value, transform) { - return typeof value !== 'undefined' && value !== null - ? transform(value) - : undefined -} + describe('rateLimitV2', function () { + beforeEach(function () { + this.projectId = 'project-id' + this.docId = 'doc-id' + this.rateLimiter = { + consume: sinon.stub().resolves({ remainingPoints: 2 }), + } + this.middleware = this.RateLimiterMiddleware.rateLimitV2( + this.rateLimiter, + { params: ['projectId', 'docId'] } + ) + this.req.params = { projectId: this.projectId, docId: this.docId } + }) + + describe('when there is no session', function () { + beforeEach(function (done) { + this.req.ip = this.ip = '1.2.3.4' + this.middleware(this.req, this.res, () => { + done() + }) + }) + + it('should call the rate limiter with the ip address', function () { + this.rateLimiter.consume.should.have.been.calledWith( + `${this.projectId}:${this.docId}:${this.ip}` + ) + }) + }) + + describe('when smoke test user', function () { + beforeEach(function (done) { + this.userId = 'smoke-test-user-id' + this.req.session = { + user: { _id: this.userId }, + } + this.settings.smokeTest = { userId: this.userId } + this.middleware(this.req, this.res, () => { + done() + }) + }) + + it('should not call the rate limiter', function () { + this.rateLimiter.consume.should.not.have.been.called + }) + }) + + describe('when under the rate limit with logged in user', function () { + beforeEach(function (done) { + this.userId = 'user-id' + this.req.session = { + user: { _id: this.userId }, + } + this.middleware(this.req, this.res, () => { + done() + }) + }) + + it('should call the rate limiter backend with the userId', function () { + this.rateLimiter.consume.should.have.been.calledWith( + `${this.projectId}:${this.docId}:${this.userId}` + ) + }) + }) + + describe('when under the rate limit with anonymous user', function () { + beforeEach(function (done) { + this.req.ip = '1.2.3.4' + this.middleware(this.req, this.res, () => { + done() + }) + }) + + it('should call the rate limiter backend with the ip address', function () { + this.rateLimiter.consume.should.have.been.calledWith( + `${this.projectId}:${this.docId}:${this.req.ip}` + ) + }) + }) + + describe('when over the rate limit', function () { + beforeEach(function (done) { + this.userId = 'user-id' + this.req.session = { + user: { _id: this.userId }, + } + this.res.end.callsFake(() => { + done() + }) + this.rateLimiter.consume.rejects({ remainingPoints: 0 }) + this.middleware(this.req, this.res, this.next) + }) + + it('should return a 429', function () { + this.res.status.should.have.been.calledWith(429) + }) + + it('should not continue', function () { + this.next.should.not.have.been.called + }) + }) + }) +})