diff --git a/services/web/app/src/infrastructure/Csrf.mjs b/services/web/app/src/infrastructure/Csrf.mjs index 2192db149f..fd46ab06ee 100644 --- a/services/web/app/src/infrastructure/Csrf.mjs +++ b/services/web/app/src/infrastructure/Csrf.mjs @@ -1,14 +1,54 @@ import csurf from 'csurf' -import { promisify } from 'node:util' import Settings from '@overleaf/settings' import logger from '@overleaf/logger' +import { callbackify } from '@overleaf/promise-utils' const csrf = csurf() +function blockCrossOriginRequests() { + return function (req, res, next) { + const { origin } = req.headers + // NOTE: Only cross-origin requests must have an origin header set. + if (origin && !Settings.allowedOrigins.includes(origin)) { + logger.warn({ req }, 'blocking cross-origin request') + return res.sendStatus(403) + } + next() + } +} + +function validateRequest(req) { + // run a dummy csrf check to see if it returns an error + return new Promise((resolve, reject) => { + csrf(req, null, err => { + if (err) { + reject(err) + } else { + resolve() + } + }) + }) +} + +async function validateToken(token, session) { + if (token == null) { + throw new Error('missing token') + } + // run a dummy csrf check to see if it returns an error + // use this to simulate a csrf check regardless of req method, headers &c. + const req = { + body: { + _csrf: token, + }, + headers: {}, + method: 'POST', + session, + } + await validateRequest(req) +} + // Wrapper for `csurf` middleware that provides a list of routes that can be excluded from csrf checks. // -// Include with `Csrf = require('./Csrf')` -// // Add the middleware to the router with: // myRouter.csrf = new Csrf() // myRouter.use webRouter.csrf.middleware @@ -16,26 +56,14 @@ const csrf = csurf() // myRouter.csrf.disableDefaultCsrfProtection "/path" "METHOD" // // To validate the csrf token in a request to ensure that it's valid, you can use `validateRequest`, which takes a -// request object and calls a callback with an error if invalid. +// request object rejects with an error if invalid. -class Csrf { +export class Csrf { constructor() { this.middleware = this.middleware.bind(this) this.excluded_routes = {} } - static blockCrossOriginRequests() { - return function (req, res, next) { - const { origin } = req.headers - // NOTE: Only cross-origin requests must have an origin header set. - if (origin && !Settings.allowedOrigins.includes(origin)) { - logger.warn({ req }, 'blocking cross-origin request') - return res.sendStatus(403) - } - next() - } - } - disableDefaultCsrfProtection(route, method) { if (!this.excluded_routes[route]) { this.excluded_routes[route] = {} @@ -62,36 +90,14 @@ class Csrf { csrf(req, res, next) } } - - static validateRequest(req, cb) { - // run a dummy csrf check to see if it returns an error - if (cb == null) { - cb = function (valid) {} - } - csrf(req, null, err => cb(err)) - } - - static validateToken(token, session, cb) { - if (token == null) { - return cb(new Error('missing token')) - } - // run a dummy csrf check to see if it returns an error - // use this to simulate a csrf check regardless of req method, headers &c. - const req = { - body: { - _csrf: token, - }, - headers: {}, - method: 'POST', - session, - } - Csrf.validateRequest(req, cb) - } } -Csrf.promises = { - validateRequest: promisify(Csrf.validateRequest), - validateToken: promisify(Csrf.validateToken), +export default { + blockCrossOriginRequests, + validateRequest: callbackify(validateRequest), + validateToken: callbackify(validateToken), + promises: { + validateRequest, + validateToken, + }, } - -export default Csrf diff --git a/services/web/app/src/infrastructure/Server.mjs b/services/web/app/src/infrastructure/Server.mjs index d973452bd2..2da657f4a3 100644 --- a/services/web/app/src/infrastructure/Server.mjs +++ b/services/web/app/src/infrastructure/Server.mjs @@ -6,7 +6,7 @@ import csp, { removeCSPHeaders } from './CSP.mjs' import Router from '../router.mjs' import helmet from 'helmet' import UserSessionsRedis from '../Features/User/UserSessionsRedis.mjs' -import Csrf from './Csrf.mjs' +import Csrf, { Csrf as CsrfClass } from './Csrf.mjs' import HttpPermissionsPolicyMiddleware from './HttpPermissionsPolicy.mjs' import SessionAutostartMiddleware from './SessionAutostartMiddleware.mjs' import AnalyticsManager from '../Features/Analytics/AnalyticsManager.mjs' @@ -232,7 +232,7 @@ Modules.hooks.fire('passportSetup', passport, err => { await Modules.applyNonCsrfRouter(webRouter, privateApiRouter, publicApiRouter) -webRouter.csrf = new Csrf() +webRouter.csrf = new CsrfClass() webRouter.use(webRouter.csrf.middleware) webRouter.use(translations.i18nMiddleware) webRouter.use(translations.setLangBasedOnDomainMiddleware) diff --git a/services/web/test/unit/src/infrastructure/Csrf.test.mjs b/services/web/test/unit/src/infrastructure/Csrf.test.mjs index 554170bd66..30e8157f17 100644 --- a/services/web/test/unit/src/infrastructure/Csrf.test.mjs +++ b/services/web/test/unit/src/infrastructure/Csrf.test.mjs @@ -13,8 +13,10 @@ describe('Csrf', function () { default: sinon.stub().returns(ctx.csurf_csrf), })) - ctx.Csrf = (await import(modulePath)).default - ctx.csrf = new ctx.Csrf() + const module = await import(modulePath) + ctx.Csrf = module.default + ctx.CsrfClass = module.Csrf + ctx.csrf = new ctx.CsrfClass() ctx.next = sinon.stub() ctx.path = '/foo/bar' ctx.req = { @@ -87,7 +89,7 @@ describe('Csrf', function () { ctx.csurf_csrf.callsArgWith(2, err) - const csrf = new ctx.Csrf() + const csrf = new ctx.CsrfClass() csrf.disableDefaultCsrfProtection(ctx.path, 'POST') csrf.middleware(ctx.req, ctx.res, ctx.next) expect(ctx.next.calledWith(err)).to.equal(true) @@ -98,55 +100,51 @@ describe('Csrf', function () { describe('validateRequest', function () { describe('when the request is invalid', function () { - it('calls the callback with error', function (ctx) { - ctx.cb = sinon.stub() - ctx.Csrf.validateRequest(ctx.req, ctx.cb) - expect(ctx.cb.calledWith(ctx.err)).to.equal(true) + it('rejects the promise', async function (ctx) { + await expect( + ctx.Csrf.promises.validateRequest(ctx.req) + ).to.be.rejectedWith(ctx.err) }) }) describe('when the request is valid', function () { - it('calls the callback without an error', async function (ctx) { + it('resolves the promise', async function (ctx) { + vi.resetModules() vi.doMock('csurf', () => ({ - default: (ctx.csurf = sinon - .stub() - .returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))), + default: sinon.stub().returns(sinon.stub().callsArg(2)), })) ctx.Csrf = (await import(modulePath)).default - ctx.cb = sinon.stub() - ctx.Csrf.validateRequest(ctx.req, ctx.cb) - expect(ctx.cb.calledWith()).to.equal(true) + await expect(ctx.Csrf.promises.validateRequest(ctx.req)).to.eventually + .be.fulfilled }) }) }) describe('validateToken', function () { describe('when the request is invalid', function () { - it('calls the callback with `false`', function (ctx) { - ctx.cb = sinon.stub() - ctx.Csrf.validateToken('token', {}, ctx.cb) - expect(ctx.cb.calledWith(ctx.err)).to.equal(true) + it('rejects the promise', async function (ctx) { + await expect( + ctx.Csrf.promises.validateToken('token', {}) + ).to.be.rejectedWith(ctx.err) }) }) describe('when the request is valid', function () { - it('calls the callback with `true`', async function (ctx) { + it('resolves the promise', async function (ctx) { + vi.resetModules() vi.doMock('csurf', () => ({ - default: (ctx.csurf = sinon - .stub() - .returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))), + default: sinon.stub().returns(sinon.stub().callsArg(2)), })) ctx.Csrf = (await import(modulePath)).default - ctx.cb = sinon.stub() - ctx.Csrf.validateToken('goodtoken', {}, ctx.cb) - expect(ctx.cb.calledWith()).to.equal(true) + await expect(ctx.Csrf.promises.validateToken('goodtoken', {})).to + .eventually.be.fulfilled }) }) describe('when there is no token', function () { - it('calls the callback with an error', async function (ctx) { + it('throws an error', async function (ctx) { vi.doMock('csurf', () => ({ default: (ctx.csurf = sinon .stub() @@ -154,10 +152,9 @@ describe('Csrf', function () { })) ctx.Csrf = (await import(modulePath)).default - ctx.cb = sinon.stub() - ctx.Csrf.validateToken(null, {}, error => { - expect(error).to.exist - }) + await expect( + ctx.Csrf.promises.validateToken(null, {}) + ).to.be.rejectedWith('missing token') }) }) })