mirror of
https://github.com/yu-i-i/overleaf-cep.git
synced 2026-05-23 17:19:37 +02:00
Merge pull request #30460 from overleaf/ii-await-csrf
[web] Promisify Csrf GitOrigin-RevId: 00e1d8e3d79c58e4cb614574415cba3a1b21f1f2
This commit is contained in:
@@ -1,14 +1,54 @@
|
||||
import csurf from 'csurf'
|
||||
import { promisify } from 'node:util'
|
||||
import Settings from '@overleaf/settings'
|
||||
import logger from '@overleaf/logger'
|
||||
import { callbackify } from '@overleaf/promise-utils'
|
||||
|
||||
const csrf = csurf()
|
||||
|
||||
function blockCrossOriginRequests() {
|
||||
return function (req, res, next) {
|
||||
const { origin } = req.headers
|
||||
// NOTE: Only cross-origin requests must have an origin header set.
|
||||
if (origin && !Settings.allowedOrigins.includes(origin)) {
|
||||
logger.warn({ req }, 'blocking cross-origin request')
|
||||
return res.sendStatus(403)
|
||||
}
|
||||
next()
|
||||
}
|
||||
}
|
||||
|
||||
function validateRequest(req) {
|
||||
// run a dummy csrf check to see if it returns an error
|
||||
return new Promise((resolve, reject) => {
|
||||
csrf(req, null, err => {
|
||||
if (err) {
|
||||
reject(err)
|
||||
} else {
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async function validateToken(token, session) {
|
||||
if (token == null) {
|
||||
throw new Error('missing token')
|
||||
}
|
||||
// run a dummy csrf check to see if it returns an error
|
||||
// use this to simulate a csrf check regardless of req method, headers &c.
|
||||
const req = {
|
||||
body: {
|
||||
_csrf: token,
|
||||
},
|
||||
headers: {},
|
||||
method: 'POST',
|
||||
session,
|
||||
}
|
||||
await validateRequest(req)
|
||||
}
|
||||
|
||||
// Wrapper for `csurf` middleware that provides a list of routes that can be excluded from csrf checks.
|
||||
//
|
||||
// Include with `Csrf = require('./Csrf')`
|
||||
//
|
||||
// Add the middleware to the router with:
|
||||
// myRouter.csrf = new Csrf()
|
||||
// myRouter.use webRouter.csrf.middleware
|
||||
@@ -16,26 +56,14 @@ const csrf = csurf()
|
||||
// myRouter.csrf.disableDefaultCsrfProtection "/path" "METHOD"
|
||||
//
|
||||
// To validate the csrf token in a request to ensure that it's valid, you can use `validateRequest`, which takes a
|
||||
// request object and calls a callback with an error if invalid.
|
||||
// request object rejects with an error if invalid.
|
||||
|
||||
class Csrf {
|
||||
export class Csrf {
|
||||
constructor() {
|
||||
this.middleware = this.middleware.bind(this)
|
||||
this.excluded_routes = {}
|
||||
}
|
||||
|
||||
static blockCrossOriginRequests() {
|
||||
return function (req, res, next) {
|
||||
const { origin } = req.headers
|
||||
// NOTE: Only cross-origin requests must have an origin header set.
|
||||
if (origin && !Settings.allowedOrigins.includes(origin)) {
|
||||
logger.warn({ req }, 'blocking cross-origin request')
|
||||
return res.sendStatus(403)
|
||||
}
|
||||
next()
|
||||
}
|
||||
}
|
||||
|
||||
disableDefaultCsrfProtection(route, method) {
|
||||
if (!this.excluded_routes[route]) {
|
||||
this.excluded_routes[route] = {}
|
||||
@@ -62,36 +90,14 @@ class Csrf {
|
||||
csrf(req, res, next)
|
||||
}
|
||||
}
|
||||
|
||||
static validateRequest(req, cb) {
|
||||
// run a dummy csrf check to see if it returns an error
|
||||
if (cb == null) {
|
||||
cb = function (valid) {}
|
||||
}
|
||||
csrf(req, null, err => cb(err))
|
||||
}
|
||||
|
||||
static validateToken(token, session, cb) {
|
||||
if (token == null) {
|
||||
return cb(new Error('missing token'))
|
||||
}
|
||||
// run a dummy csrf check to see if it returns an error
|
||||
// use this to simulate a csrf check regardless of req method, headers &c.
|
||||
const req = {
|
||||
body: {
|
||||
_csrf: token,
|
||||
},
|
||||
headers: {},
|
||||
method: 'POST',
|
||||
session,
|
||||
}
|
||||
Csrf.validateRequest(req, cb)
|
||||
}
|
||||
}
|
||||
|
||||
Csrf.promises = {
|
||||
validateRequest: promisify(Csrf.validateRequest),
|
||||
validateToken: promisify(Csrf.validateToken),
|
||||
export default {
|
||||
blockCrossOriginRequests,
|
||||
validateRequest: callbackify(validateRequest),
|
||||
validateToken: callbackify(validateToken),
|
||||
promises: {
|
||||
validateRequest,
|
||||
validateToken,
|
||||
},
|
||||
}
|
||||
|
||||
export default Csrf
|
||||
|
||||
@@ -6,7 +6,7 @@ import csp, { removeCSPHeaders } from './CSP.mjs'
|
||||
import Router from '../router.mjs'
|
||||
import helmet from 'helmet'
|
||||
import UserSessionsRedis from '../Features/User/UserSessionsRedis.mjs'
|
||||
import Csrf from './Csrf.mjs'
|
||||
import Csrf, { Csrf as CsrfClass } from './Csrf.mjs'
|
||||
import HttpPermissionsPolicyMiddleware from './HttpPermissionsPolicy.mjs'
|
||||
import SessionAutostartMiddleware from './SessionAutostartMiddleware.mjs'
|
||||
import AnalyticsManager from '../Features/Analytics/AnalyticsManager.mjs'
|
||||
@@ -232,7 +232,7 @@ Modules.hooks.fire('passportSetup', passport, err => {
|
||||
|
||||
await Modules.applyNonCsrfRouter(webRouter, privateApiRouter, publicApiRouter)
|
||||
|
||||
webRouter.csrf = new Csrf()
|
||||
webRouter.csrf = new CsrfClass()
|
||||
webRouter.use(webRouter.csrf.middleware)
|
||||
webRouter.use(translations.i18nMiddleware)
|
||||
webRouter.use(translations.setLangBasedOnDomainMiddleware)
|
||||
|
||||
@@ -13,8 +13,10 @@ describe('Csrf', function () {
|
||||
default: sinon.stub().returns(ctx.csurf_csrf),
|
||||
}))
|
||||
|
||||
ctx.Csrf = (await import(modulePath)).default
|
||||
ctx.csrf = new ctx.Csrf()
|
||||
const module = await import(modulePath)
|
||||
ctx.Csrf = module.default
|
||||
ctx.CsrfClass = module.Csrf
|
||||
ctx.csrf = new ctx.CsrfClass()
|
||||
ctx.next = sinon.stub()
|
||||
ctx.path = '/foo/bar'
|
||||
ctx.req = {
|
||||
@@ -87,7 +89,7 @@ describe('Csrf', function () {
|
||||
|
||||
ctx.csurf_csrf.callsArgWith(2, err)
|
||||
|
||||
const csrf = new ctx.Csrf()
|
||||
const csrf = new ctx.CsrfClass()
|
||||
csrf.disableDefaultCsrfProtection(ctx.path, 'POST')
|
||||
csrf.middleware(ctx.req, ctx.res, ctx.next)
|
||||
expect(ctx.next.calledWith(err)).to.equal(true)
|
||||
@@ -98,55 +100,51 @@ describe('Csrf', function () {
|
||||
|
||||
describe('validateRequest', function () {
|
||||
describe('when the request is invalid', function () {
|
||||
it('calls the callback with error', function (ctx) {
|
||||
ctx.cb = sinon.stub()
|
||||
ctx.Csrf.validateRequest(ctx.req, ctx.cb)
|
||||
expect(ctx.cb.calledWith(ctx.err)).to.equal(true)
|
||||
it('rejects the promise', async function (ctx) {
|
||||
await expect(
|
||||
ctx.Csrf.promises.validateRequest(ctx.req)
|
||||
).to.be.rejectedWith(ctx.err)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when the request is valid', function () {
|
||||
it('calls the callback without an error', async function (ctx) {
|
||||
it('resolves the promise', async function (ctx) {
|
||||
vi.resetModules()
|
||||
vi.doMock('csurf', () => ({
|
||||
default: (ctx.csurf = sinon
|
||||
.stub()
|
||||
.returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))),
|
||||
default: sinon.stub().returns(sinon.stub().callsArg(2)),
|
||||
}))
|
||||
|
||||
ctx.Csrf = (await import(modulePath)).default
|
||||
ctx.cb = sinon.stub()
|
||||
ctx.Csrf.validateRequest(ctx.req, ctx.cb)
|
||||
expect(ctx.cb.calledWith()).to.equal(true)
|
||||
await expect(ctx.Csrf.promises.validateRequest(ctx.req)).to.eventually
|
||||
.be.fulfilled
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateToken', function () {
|
||||
describe('when the request is invalid', function () {
|
||||
it('calls the callback with `false`', function (ctx) {
|
||||
ctx.cb = sinon.stub()
|
||||
ctx.Csrf.validateToken('token', {}, ctx.cb)
|
||||
expect(ctx.cb.calledWith(ctx.err)).to.equal(true)
|
||||
it('rejects the promise', async function (ctx) {
|
||||
await expect(
|
||||
ctx.Csrf.promises.validateToken('token', {})
|
||||
).to.be.rejectedWith(ctx.err)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when the request is valid', function () {
|
||||
it('calls the callback with `true`', async function (ctx) {
|
||||
it('resolves the promise', async function (ctx) {
|
||||
vi.resetModules()
|
||||
vi.doMock('csurf', () => ({
|
||||
default: (ctx.csurf = sinon
|
||||
.stub()
|
||||
.returns((ctx.csurf_csrf = sinon.stub().callsArg(2)))),
|
||||
default: sinon.stub().returns(sinon.stub().callsArg(2)),
|
||||
}))
|
||||
|
||||
ctx.Csrf = (await import(modulePath)).default
|
||||
ctx.cb = sinon.stub()
|
||||
ctx.Csrf.validateToken('goodtoken', {}, ctx.cb)
|
||||
expect(ctx.cb.calledWith()).to.equal(true)
|
||||
await expect(ctx.Csrf.promises.validateToken('goodtoken', {})).to
|
||||
.eventually.be.fulfilled
|
||||
})
|
||||
})
|
||||
|
||||
describe('when there is no token', function () {
|
||||
it('calls the callback with an error', async function (ctx) {
|
||||
it('throws an error', async function (ctx) {
|
||||
vi.doMock('csurf', () => ({
|
||||
default: (ctx.csurf = sinon
|
||||
.stub()
|
||||
@@ -154,10 +152,9 @@ describe('Csrf', function () {
|
||||
}))
|
||||
|
||||
ctx.Csrf = (await import(modulePath)).default
|
||||
ctx.cb = sinon.stub()
|
||||
ctx.Csrf.validateToken(null, {}, error => {
|
||||
expect(error).to.exist
|
||||
})
|
||||
await expect(
|
||||
ctx.Csrf.promises.validateToken(null, {})
|
||||
).to.be.rejectedWith('missing token')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user