Merge pull request #30460 from overleaf/ii-await-csrf

[web] Promisify Csrf

GitOrigin-RevId: 00e1d8e3d79c58e4cb614574415cba3a1b21f1f2
This commit is contained in:
ilkin-overleaf
2026-01-08 12:03:16 +02:00
committed by Copybot
parent 5e49f421c2
commit e0483dd6c3
3 changed files with 82 additions and 79 deletions

View File

@@ -1,14 +1,54 @@
import csurf from 'csurf'
import { promisify } from 'node:util'
import Settings from '@overleaf/settings'
import logger from '@overleaf/logger'
import { callbackify } from '@overleaf/promise-utils'
const csrf = csurf()
function blockCrossOriginRequests() {
return function (req, res, next) {
const { origin } = req.headers
// NOTE: Only cross-origin requests must have an origin header set.
if (origin && !Settings.allowedOrigins.includes(origin)) {
logger.warn({ req }, 'blocking cross-origin request')
return res.sendStatus(403)
}
next()
}
}
function validateRequest(req) {
// run a dummy csrf check to see if it returns an error
return new Promise((resolve, reject) => {
csrf(req, null, err => {
if (err) {
reject(err)
} else {
resolve()
}
})
})
}
async function validateToken(token, session) {
if (token == null) {
throw new Error('missing token')
}
// run a dummy csrf check to see if it returns an error
// use this to simulate a csrf check regardless of req method, headers &c.
const req = {
body: {
_csrf: token,
},
headers: {},
method: 'POST',
session,
}
await validateRequest(req)
}
// Wrapper for `csurf` middleware that provides a list of routes that can be excluded from csrf checks.
//
// Include with `Csrf = require('./Csrf')`
//
// Add the middleware to the router with:
// myRouter.csrf = new Csrf()
// myRouter.use webRouter.csrf.middleware
@@ -16,26 +56,14 @@ const csrf = csurf()
// myRouter.csrf.disableDefaultCsrfProtection "/path" "METHOD"
//
// To validate the csrf token in a request to ensure that it's valid, you can use `validateRequest`, which takes a
// request object and calls a callback with an error if invalid.
// request object rejects with an error if invalid.
class Csrf {
export class Csrf {
constructor() {
this.middleware = this.middleware.bind(this)
this.excluded_routes = {}
}
static blockCrossOriginRequests() {
return function (req, res, next) {
const { origin } = req.headers
// NOTE: Only cross-origin requests must have an origin header set.
if (origin && !Settings.allowedOrigins.includes(origin)) {
logger.warn({ req }, 'blocking cross-origin request')
return res.sendStatus(403)
}
next()
}
}
disableDefaultCsrfProtection(route, method) {
if (!this.excluded_routes[route]) {
this.excluded_routes[route] = {}
@@ -62,36 +90,14 @@ class Csrf {
csrf(req, res, next)
}
}
static validateRequest(req, cb) {
// run a dummy csrf check to see if it returns an error
if (cb == null) {
cb = function (valid) {}
}
csrf(req, null, err => cb(err))
}
static validateToken(token, session, cb) {
if (token == null) {
return cb(new Error('missing token'))
}
// run a dummy csrf check to see if it returns an error
// use this to simulate a csrf check regardless of req method, headers &c.
const req = {
body: {
_csrf: token,
},
headers: {},
method: 'POST',
session,
}
Csrf.validateRequest(req, cb)
}
}
Csrf.promises = {
validateRequest: promisify(Csrf.validateRequest),
validateToken: promisify(Csrf.validateToken),
export default {
blockCrossOriginRequests,
validateRequest: callbackify(validateRequest),
validateToken: callbackify(validateToken),
promises: {
validateRequest,
validateToken,
},
}
export default Csrf

View File

@@ -6,7 +6,7 @@ import csp, { removeCSPHeaders } from './CSP.mjs'
import Router from '../router.mjs'
import helmet from 'helmet'
import UserSessionsRedis from '../Features/User/UserSessionsRedis.mjs'
import Csrf from './Csrf.mjs'
import Csrf, { Csrf as CsrfClass } from './Csrf.mjs'
import HttpPermissionsPolicyMiddleware from './HttpPermissionsPolicy.mjs'
import SessionAutostartMiddleware from './SessionAutostartMiddleware.mjs'
import AnalyticsManager from '../Features/Analytics/AnalyticsManager.mjs'
@@ -232,7 +232,7 @@ Modules.hooks.fire('passportSetup', passport, err => {
await Modules.applyNonCsrfRouter(webRouter, privateApiRouter, publicApiRouter)
webRouter.csrf = new Csrf()
webRouter.csrf = new CsrfClass()
webRouter.use(webRouter.csrf.middleware)
webRouter.use(translations.i18nMiddleware)
webRouter.use(translations.setLangBasedOnDomainMiddleware)

View File

@@ -13,8 +13,10 @@ describe('Csrf', function () {
default: sinon.stub().returns(ctx.csurf_csrf),
}))
ctx.Csrf = (await import(modulePath)).default
ctx.csrf = new ctx.Csrf()
const module = await import(modulePath)
ctx.Csrf = module.default
ctx.CsrfClass = module.Csrf
ctx.csrf = new ctx.CsrfClass()
ctx.next = sinon.stub()
ctx.path = '/foo/bar'
ctx.req = {
@@ -87,7 +89,7 @@ describe('Csrf', function () {
ctx.csurf_csrf.callsArgWith(2, err)
const csrf = new ctx.Csrf()
const csrf = new ctx.CsrfClass()
csrf.disableDefaultCsrfProtection(ctx.path, 'POST')
csrf.middleware(ctx.req, ctx.res, ctx.next)
expect(ctx.next.calledWith(err)).to.equal(true)
@@ -98,55 +100,51 @@ describe('Csrf', function () {
describe('validateRequest', function () {
describe('when the request is invalid', function () {
it('calls the callback with error', function (ctx) {
ctx.cb = sinon.stub()
ctx.Csrf.validateRequest(ctx.req, ctx.cb)
expect(ctx.cb.calledWith(ctx.err)).to.equal(true)
it('rejects the promise', async function (ctx) {
await expect(
ctx.Csrf.promises.validateRequest(ctx.req)
).to.be.rejectedWith(ctx.err)
})
})
describe('when the request is valid', function () {
it('calls the callback without an error', async function (ctx) {
it('resolves the promise', async function (ctx) {
vi.resetModules()
vi.doMock('csurf', () => ({
default: (ctx.csurf = sinon
.stub()
.returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))),
default: sinon.stub().returns(sinon.stub().callsArg(2)),
}))
ctx.Csrf = (await import(modulePath)).default
ctx.cb = sinon.stub()
ctx.Csrf.validateRequest(ctx.req, ctx.cb)
expect(ctx.cb.calledWith()).to.equal(true)
await expect(ctx.Csrf.promises.validateRequest(ctx.req)).to.eventually
.be.fulfilled
})
})
})
describe('validateToken', function () {
describe('when the request is invalid', function () {
it('calls the callback with `false`', function (ctx) {
ctx.cb = sinon.stub()
ctx.Csrf.validateToken('token', {}, ctx.cb)
expect(ctx.cb.calledWith(ctx.err)).to.equal(true)
it('rejects the promise', async function (ctx) {
await expect(
ctx.Csrf.promises.validateToken('token', {})
).to.be.rejectedWith(ctx.err)
})
})
describe('when the request is valid', function () {
it('calls the callback with `true`', async function (ctx) {
it('resolves the promise', async function (ctx) {
vi.resetModules()
vi.doMock('csurf', () => ({
default: (ctx.csurf = sinon
.stub()
.returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))),
default: sinon.stub().returns(sinon.stub().callsArg(2)),
}))
ctx.Csrf = (await import(modulePath)).default
ctx.cb = sinon.stub()
ctx.Csrf.validateToken('goodtoken', {}, ctx.cb)
expect(ctx.cb.calledWith()).to.equal(true)
await expect(ctx.Csrf.promises.validateToken('goodtoken', {})).to
.eventually.be.fulfilled
})
})
describe('when there is no token', function () {
it('calls the callback with an error', async function (ctx) {
it('throws an error', async function (ctx) {
vi.doMock('csurf', () => ({
default: (ctx.csurf = sinon
.stub()
@@ -154,10 +152,9 @@ describe('Csrf', function () {
}))
ctx.Csrf = (await import(modulePath)).default
ctx.cb = sinon.stub()
ctx.Csrf.validateToken(null, {}, error => {
expect(error).to.exist
})
await expect(
ctx.Csrf.promises.validateToken(null, {})
).to.be.rejectedWith('missing token')
})
})
})