From 501e11a42a663bac6c63f716b1203e57d27fe35d Mon Sep 17 00:00:00 2001 From: Jimmy Domagala-Tang Date: Tue, 3 Mar 2026 10:01:13 -0500 Subject: [PATCH] Move feature rate limiters to shared web folder (#31855) * feat: remove old assist split test * feat: moving feature rate limiters to main shared directory for use in multiple modules * feat: base workbench rate limiter on a token specific base class * feat: rename aiErrorAssistRateLimiter to AiFeatureUsageRateLimiter to better reflect that it's for our shared AI usage quota GitOrigin-RevId: 89464d115b5904f6274756a7169e2b35945e2fc9 --- .../eslint-plugin/prefer-kebab-url-ignore.js | 1 + .../Features/Project/ProjectController.mjs | 10 +- .../AiFeatureUsageRateLimiter.mjs | 59 +++ .../FeatureUsageRateLimiter.mjs | 8 +- .../rate-limiters/TokenUsageRateLimiter.mjs | 171 +++++++ .../rate-limiters/WorkbenchRateLimiter.mjs | 56 +++ .../web/app/src/models/UserFeatureUsage.mjs | 2 +- .../js/shared/context/editor-context.tsx | 8 +- .../src/Project/ProjectController.test.mjs | 13 + .../AiFeatureUsageRateLimiter.test.mjs | 182 ++++++++ ...eatureUsageRateLimiter.sequential.test.mjs | 2 +- .../WorkbenchRateLimiter.sequential.test.mjs | 416 ++++++++++++++++++ 12 files changed, 910 insertions(+), 18 deletions(-) create mode 100644 services/web/app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs rename services/web/app/src/infrastructure/{ => rate-limiters}/FeatureUsageRateLimiter.mjs (95%) create mode 100644 services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs create mode 100644 services/web/app/src/infrastructure/rate-limiters/WorkbenchRateLimiter.mjs create mode 100644 services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs create mode 100644 services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs diff --git a/libraries/eslint-plugin/prefer-kebab-url-ignore.js b/libraries/eslint-plugin/prefer-kebab-url-ignore.js index ccb122566f..92ee23de6b 100644 --- 
a/libraries/eslint-plugin/prefer-kebab-url-ignore.js +++ b/libraries/eslint-plugin/prefer-kebab-url-ignore.js @@ -49,6 +49,7 @@ const ignoreWords = { camel: new Set([ 'addWorkflowScope', 'aiErrorAssistant', + 'aiFeatureUsage', 'beginAuth', 'brandVariationId', 'closeEditor', diff --git a/services/web/app/src/Features/Project/ProjectController.mjs b/services/web/app/src/Features/Project/ProjectController.mjs index ec0ea863d5..4fb5a65177 100644 --- a/services/web/app/src/Features/Project/ProjectController.mjs +++ b/services/web/app/src/Features/Project/ProjectController.mjs @@ -52,6 +52,7 @@ import { isStandaloneAiAddOnPlanCode } from '../Subscription/AiHelper.mjs' import SubscriptionController from '../Subscription/SubscriptionController.mjs' import { formatCurrency } from '../../util/currency.js' import UserSettingsHelper from './UserSettingsHelper.mjs' +import AiFeatureUsageRateLimiter from '../../infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs' const { isPaidSubscription } = SubscriptionHelper const { hasAdminAccess } = AdminAuthorizationHelper @@ -795,13 +796,8 @@ const _ProjectController = { let featureUsage = {} if (Features.hasFeature('saas')) { - const usagesLeft = await Modules.promises.hooks.fire( - 'remainingFeatureAllocation', - userId - ) - usagesLeft?.forEach(usage => { - featureUsage = { ...featureUsage, ...usage } - }) + featureUsage = + await AiFeatureUsageRateLimiter.getRemainingFeatureUses(userId) } await ProjectController._setWritefullTrialState( diff --git a/services/web/app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs new file mode 100644 index 0000000000..0b117a2ea5 --- /dev/null +++ b/services/web/app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs @@ -0,0 +1,59 @@ +// @ts-check + +import UserGetter from '../../Features/User/UserGetter.mjs' +import FeatureUsageRateLimiter from './FeatureUsageRateLimiter.mjs' +import 
Settings from '@overleaf/settings' +import SplitTestHandler from '../../Features/SplitTests/SplitTestHandler.mjs' + +class AiFeatureUsageRateLimiter extends FeatureUsageRateLimiter { + constructor() { + super('aiFeatureUsage') + } + + /** + * @param {string} userId + * @returns {Promise} + */ + async _getAllowance(userId) { + const user = await UserGetter.promises.getUser(userId, { + features: 1, + writefull: 1, + }) + // todo: quota clean-up: remove aiErrorAssistant checking, and split test + const inQuotaSplitTest = + await SplitTestHandler.promises.featureFlagEnabledForUser( + userId, + 'plans-2026-phase-1' + ) + + if (inQuotaSplitTest) { + const quotaTier = user?.writefull?.isPremium + ? Settings.writefull.quotaTierGranted + : user.features.aiUsageQuota + return _quotaTierToAllowance(quotaTier) + } else { + const DEFAULT_ALLOWANCE = 1 + const ADD_ON_ALLOWANCE = 200 + const hasAddOn = + user?.features?.aiErrorAssistant || user?.writefull?.isPremium + return hasAddOn ? ADD_ON_ALLOWANCE : DEFAULT_ALLOWANCE + } + } +} + +/** + * Maps a quota tier identifier to its corresponding numeric allowance + * using the configured quota grants for AI features. 
+ * + * @param {string} quotaTier - The quota tier identifier for the user + * @returns {number} The numeric allowance for the given tier + */ +function _quotaTierToAllowance(quotaTier) { + const quota = Settings.quotaGrants.ai[quotaTier] + if (!quota || typeof quota !== 'number') { + throw new Error(`Quota tier "${quotaTier}" is not initialized in settings`) + } + return Math.floor(quota) +} + +export default new AiFeatureUsageRateLimiter() diff --git a/services/web/app/src/infrastructure/FeatureUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs similarity index 95% rename from services/web/app/src/infrastructure/FeatureUsageRateLimiter.mjs rename to services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs index f067366a81..3bd5ac920b 100644 --- a/services/web/app/src/infrastructure/FeatureUsageRateLimiter.mjs +++ b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs @@ -1,7 +1,7 @@ // @ts-check -import { UserFeatureUsage } from '../models/UserFeatureUsage.mjs' -import { TooManyRequestsError } from '../Features/Errors/Errors.js' +import { UserFeatureUsage } from '../../models/UserFeatureUsage.mjs' +import { TooManyRequestsError } from '../../Features/Errors/Errors.js' const PERIOD = 24 // hours const PERIOD_IN_MILLISECONDS = PERIOD * 60 * 60 * 1000 @@ -166,9 +166,7 @@ export default class FeatureUsageRateLimiter { const pastUsageLimit = usage > allowance && refreshEpoch > Date.now() if (pastUsageLimit) { - throw new TooManyRequestsError( - `${this.featureName} assistant rate limit exceeded` - ) + throw new TooManyRequestsError(`${this.featureName} rate limit exceeded`) } } } diff --git a/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs new file mode 100644 index 0000000000..9cc25e955e --- /dev/null +++ 
b/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs @@ -0,0 +1,171 @@ +// @ts-check +import { UserFeatureUsage } from '../../models/UserFeatureUsage.mjs' +import { TooManyRequestsError } from '../../Features/Errors/Errors.js' +import AnalyticsManager from '../../Features/Analytics/AnalyticsManager.mjs' +/** @typedef {{usage?: number | null, periodStart?: Date | null}} FeatureUsage */ +/** @typedef {{remainingUsage: number, resetDate?: string}} RemainingUsage */ + +const PERIOD = 24 // hours +const PERIOD_IN_MILLISECONDS = PERIOD * 60 * 60 * 1000 + +// todo: quota clean-up: extend this off base RateLimitController and unify behaviour where possible. + +export default class TokenUsageRateLimiter { + /** + * @param {string} featureName + */ + constructor(featureName) { + this.featureName = featureName + } + + _resetFeatureUsagePipelineSection() { + return { + $set: { + features: { + [this.featureName]: { + $cond: { + if: { + $lte: [ + { + $dateAdd: { + startDate: `$features.${this.featureName}.periodStart`, + unit: 'hour', + amount: PERIOD, + }, + }, + '$$NOW', + ], + }, + then: { + usage: 0, + periodStart: '$$NOW', + }, + else: `$features.${this.featureName}`, + }, + }, + }, + }, + } + } + + /** + * + * @param {string} _userId + * @returns {Promise} + */ + async _getAllowance(_userId) { + throw new Error('_getAllowance must be implemented by subclasses') + } + + async recordUsage(userId, res, amount) { + const allowance = await this._getAllowance(userId) + + const featureUsages = await UserFeatureUsage.findOneAndUpdate( + { _id: userId }, + [ + this._resetFeatureUsagePipelineSection(), + { + $set: { + features: { + [this.featureName]: { + usage: { + $add: [`$features.${this.featureName}.usage`, amount], + }, + }, + }, + }, + }, + ], + { + new: true, + upsert: true, + } + ).exec() + + const featureUsage = featureUsages.features?.[this.featureName] ?? 
{} + this.setRateLimitHeaders(res, featureUsage, allowance) + } + + /** + * + * @param {string} userId + * @returns {Promise} + */ + async getCurrentUsage(userId) { + const reportedUsage = await UserFeatureUsage.findOne({ _id: userId }).exec() + const featureUsage = reportedUsage?.features?.[this.featureName] ?? {} + return { + usage: featureUsage.usage ?? 0, + periodStart: featureUsage.periodStart ?? new Date(), + } + } + + /** + * + * @param {string} userId + * @param {import('express').Response} res + */ + async checkUsage(userId, res) { + const allowance = await this._getAllowance(userId) + const currentUsage = await this.getCurrentUsage(userId) + const periodStart = currentUsage.periodStart ?? new Date() + if (periodStart.getTime() + PERIOD_IN_MILLISECONDS <= Date.now()) { + // Period has expired, so reset usage + currentUsage.usage = 0 + currentUsage.periodStart = new Date() + } + this.setRateLimitHeaders(res, currentUsage, allowance) + if ((currentUsage.usage ?? 0) >= allowance) { + await AnalyticsManager.recordEventForUser( + userId, + 'ai-token-usage-limit-exceeded' + ) + + throw new TooManyRequestsError({ + message: `${this.featureName} rate limit exceeded`, + info: { + userId, + }, + }) + } + } + + /** + * + * @param {import('express').Response} res + * @param {FeatureUsage} featureUsage + * @param {number} allowance + */ + setRateLimitHeaders(res, featureUsage, allowance) { + const periodStart = featureUsage.periodStart ?? new Date() + const usage = featureUsage.usage ?? 0 + const refreshEpoch = periodStart.getTime() + PERIOD_IN_MILLISECONDS + const secondsTillReset = Math.ceil((refreshEpoch - Date.now()) / 1000) + + if (!res.headersSent) { + res.set('RateLimit-Limit', allowance.toString()) + res.set('RateLimit-Remaining', Math.max(0, allowance - usage).toString()) + res.set('RateLimit-Reset', Math.max(0, secondsTillReset).toString()) + } + } + + /** + * Calculates a weighted token usage based on cost incurred for different token + * types. 
+ * + * @param {import('ai').LanguageModelUsage} tokenUsage + * @return {number} + */ + calculateTokenUsage(tokenUsage) { + const { + outputTokens, + inputTokenDetails: { noCacheTokens, cacheReadTokens }, + } = tokenUsage + + return Math.ceil( + (noCacheTokens ?? 0) + + (outputTokens ?? 0) * 10 + + (cacheReadTokens ?? 0) * 0.1 + ) + } +} diff --git a/services/web/app/src/infrastructure/rate-limiters/WorkbenchRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/WorkbenchRateLimiter.mjs new file mode 100644 index 0000000000..ca1ef0ddda --- /dev/null +++ b/services/web/app/src/infrastructure/rate-limiters/WorkbenchRateLimiter.mjs @@ -0,0 +1,56 @@ +// @ts-check +import SplitTestHandler from '../../Features/SplitTests/SplitTestHandler.mjs' +import UserGetter from '../../Features/User/UserGetter.mjs' +import TokenUsageRateLimiter from './TokenUsageRateLimiter.mjs' +/** @typedef {{usage?: number | null, periodStart?: Date | null}} FeatureUsage */ +/** @typedef {{remainingUsage: number, resetDate?: string}} RemainingUsage */ + +const DEFAULT_USER_TOKEN_ALLOWANCE = 8_000_000 +const ALPHA_USER_TOKEN_ALLOWANCE = 8_000_000 + +class WorkbenchRateLimiter extends TokenUsageRateLimiter { + constructor() { + super('aiWorkbench') + } + + /** + * @param {string} userId + * @returns {Promise} + */ + async _getAllowance(userId) { + const splitTestAssignment = + await SplitTestHandler.promises.getAssignmentForUser( + userId, + 'ai-workbench-release' + ) + const inSplitTest = splitTestAssignment.variant === 'enabled' + if (!inSplitTest) { + return 0 + } + const user = await UserGetter.promises.getUser(userId, { + features: 1, + writefull: 1, + alphaProgram: 1, + }) + + if (user?.alphaProgram) { + return ALPHA_USER_TOKEN_ALLOWANCE + } + + // todo: quota clean-up: remove split test + let hasAddOn + const inQuotaSplitTest = + await SplitTestHandler.promises.featureFlagEnabledForUser( + userId, + 'plans-2026-phase-1' + ) + if (inQuotaSplitTest) { + // post rollout, all users 
have the same token limit (fair usage) + return DEFAULT_USER_TOKEN_ALLOWANCE + } else { + hasAddOn = user.features.aiErrorAssistant || user.writefull?.isPremium + return hasAddOn ? DEFAULT_USER_TOKEN_ALLOWANCE : 0 + } + } +} +export default new WorkbenchRateLimiter() diff --git a/services/web/app/src/models/UserFeatureUsage.mjs b/services/web/app/src/models/UserFeatureUsage.mjs index 1427715dfb..0fa765a2a9 100644 --- a/services/web/app/src/models/UserFeatureUsage.mjs +++ b/services/web/app/src/models/UserFeatureUsage.mjs @@ -8,7 +8,7 @@ const Usage = new Schema({ export const UserFeatureUsageSchema = new Schema({ features: { - aiErrorAssistant: Usage, + aiFeatureUsage: Usage, aiWorkbench: Usage, }, }) diff --git a/services/web/frontend/js/shared/context/editor-context.tsx b/services/web/frontend/js/shared/context/editor-context.tsx index b6f79fec43..a3821b7b03 100644 --- a/services/web/frontend/js/shared/context/editor-context.tsx +++ b/services/web/frontend/js/shared/context/editor-context.tsx @@ -76,15 +76,15 @@ export const EditorProvider: FC = ({ children }) => { const [hasPremiumSuggestion, setHasPremiumSuggestion] = useState( () => { return Boolean( - featureUsage?.aiErrorAssistant && - featureUsage?.aiErrorAssistant.remainingUsage > 0 + featureUsage?.aiFeatureUsage && + featureUsage?.aiFeatureUsage.remainingUsage > 0 ) } ) const [premiumSuggestionResetDate, setPremiumSuggestionResetDate] = useState(() => { - return featureUsage?.aiErrorAssistant?.resetDate - ? new Date(featureUsage.aiErrorAssistant.resetDate) + return featureUsage?.aiFeatureUsage?.resetDate + ? 
new Date(featureUsage.aiFeatureUsage.resetDate) : new Date() }) diff --git a/services/web/test/unit/src/Project/ProjectController.test.mjs b/services/web/test/unit/src/Project/ProjectController.test.mjs index ee4fa2d0c2..07e076f77f 100644 --- a/services/web/test/unit/src/Project/ProjectController.test.mjs +++ b/services/web/test/unit/src/Project/ProjectController.test.mjs @@ -230,6 +230,12 @@ describe('ProjectController', function () { promises: { hooks: { fire: sinon.stub().resolves() } }, } + ctx.AiFeatureUsageRateLimiter = { + getRemainingFeatureUses: sinon.stub().resolves({ + aiFeatureUsage: { remainingUsage: 0 }, + }), + } + vi.doMock('mongodb-legacy', () => ({ default: { ObjectId }, })) @@ -485,6 +491,13 @@ describe('ProjectController', function () { default: ctx.Modules, })) + vi.doMock( + '../../../../app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter', + () => ({ + default: ctx.AiFeatureUsageRateLimiter, + }) + ) + ctx.ProjectController = (await import(MODULE_PATH)).default ctx.projectName = '£12321jkj9ujkljds' diff --git a/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs b/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs new file mode 100644 index 0000000000..2a8e817706 --- /dev/null +++ b/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs @@ -0,0 +1,182 @@ +import { expect, vi } from 'vitest' +import sinon from 'sinon' +import mongodb from 'mongodb-legacy' +const ObjectId = mongodb.ObjectId + +vi.mock('../../../../../app/src/Features/Errors/Errors.js', () => { + return vi.importActual('../../../../../app/src/Features/Errors/Errors.js') +}) + +const modulePath = + '../../../../app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs' + +describe('AiFeatureUsageRateLimiter', function () { + beforeEach(async function (ctx) { + ctx.userId = new ObjectId().toString() + + ctx.UserFeatureUsageModel = { + findOneAndUpdate: sinon.stub().returns({ + exec: 
sinon.stub().resolves({ + features: { + aiFeatureUsage: { + usage: 0, + periodStart: new Date(), + }, + }, + }), + }), + findOne: sinon.stub().returns({ + exec: sinon.stub().resolves({ + features: { + aiFeatureUsage: { + usage: 0, + periodStart: new Date(), + }, + }, + }), + }), + } + + ctx.user = { + features: { aiUsageQuota: 'basic' }, + writefull: { isPremium: false }, + } + ctx.userWithOLBundle = { + features: { aiUsageQuota: 'unlimited' }, + writefull: { isPremium: false }, + } + ctx.userWithOLBundleThroughWf = { + features: { aiUsageQuota: 'basic' }, + writefull: { isPremium: true }, + } + + ctx.UserGetter = { + promises: { + getUser: sinon.stub().resolves(ctx.user), + }, + } + + ctx.settings = { + writefull: { + quotaTierGranted: 'unlimited', + }, + aiFeatures: { + freeTrialQuota: 'basic', + unlimitedQuota: 'unlimited', + }, + quotaGrants: { + ai: { + basic: 5, + unlimited: 200, + }, + }, + } + + ctx.SplitTestHandler = { + promises: { + featureFlagEnabledForUser: sinon.stub().resolves(true), + }, + } + + vi.doMock('@overleaf/settings', () => ({ + default: ctx.settings, + })) + + vi.doMock('../../../../app/src/models/UserFeatureUsage', () => ({ + UserFeatureUsage: ctx.UserFeatureUsageModel, + })) + + vi.doMock('../../../../app/src/Features/User/UserGetter.mjs', () => ({ + default: ctx.UserGetter, + })) + + vi.doMock( + '../../../../app/src/Features/SplitTests/SplitTestHandler.mjs', + () => ({ + default: ctx.SplitTestHandler, + }) + ) + + const module = await import(modulePath) + ctx.AiFeatureUsageRateLimiter = module.default + }) + + describe('useFeature', function () { + describe('with some remaining allowance left', function () { + it('should suceed', async function (ctx) { + const res = { set: () => null } + await expect(ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res)) + .to.not.be.rejected + }) + }) + + describe('with 0 allowance left', function () { + beforeEach(function (ctx) { + ctx.UserFeatureUsageModel.findOneAndUpdate = 
sinon.stub().returns({ + exec: sinon.stub().resolves({ + features: { + aiFeatureUsage: { + usage: ctx.settings.quotaGrants.ai.unlimited + 1, + periodStart: new Date(), + }, + }, + }), + }) + }) + + it('should be rejected with TooManyRequestsError', async function (ctx) { + const res = { set: () => null } + await expect( + ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res) + ).to.be.rejectedWith('aiFeatureUsage rate limit exceeded') + }) + }) + }) + + describe('getRemainingFeatureUses', function () { + beforeEach(async function (ctx) { + ctx.UserFeatureUsageModel.findOneAndUpdate = sinon.stub().returns({ + exec: sinon.stub().resolves({ + features: { + aiFeatureUsage: { + usage: 0, + periodStart: new Date(), + }, + }, + }), + }) + ctx.UserGetter.promises.getUser = sinon.stub() + }) + + it('should give higher usage for OL assist bundle owners', async function (ctx) { + ctx.UserGetter.promises.getUser = sinon + .stub() + .resolves(ctx.userWithOLBundle) + const usages = + await ctx.AiFeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + await expect(usages.aiFeatureUsage.remainingUsage).to.equal( + ctx.settings.quotaGrants.ai.unlimited + ) + }) + + it('should give higher usage for assist bundle owners who have the feature via Writefull', async function (ctx) { + ctx.UserGetter.promises.getUser = sinon + .stub() + .resolves(ctx.userWithOLBundleThroughWf) + const usages = + await ctx.AiFeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + await expect(usages.aiFeatureUsage.remainingUsage).to.equal( + ctx.settings.quotaGrants.ai.unlimited + ) + }) + + it('should calculate remaining usages for free users', async function (ctx) { + ctx.UserGetter.promises.getUser = sinon.stub().resolves(ctx.user) + const usages = + await ctx.AiFeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + await expect(usages.aiFeatureUsage.remainingUsage).to.equal( + ctx.settings.quotaGrants.ai.basic + ) + }) + }) +}) diff --git 
a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs index 0e5ba40677..bfcea70846 100644 --- a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs +++ b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs @@ -19,7 +19,7 @@ vi.mock('../../../../app/src/Features/Errors/Errors.js', () => { const MOCKED_FEATURE_NAME = 'aiWorkbench' const modulePath = - '../../../../app/src/infrastructure/FeatureUsageRateLimiter.mjs' + '../../../../app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter' describe('FeatureUsageRateLimiter', function () { beforeAll(async function () { diff --git a/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs b/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs new file mode 100644 index 0000000000..9de808d694 --- /dev/null +++ b/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs @@ -0,0 +1,416 @@ +import { beforeAll, beforeEach, describe, it, vi, expect } from 'vitest' +import sinon from 'sinon' +import mongodb from 'mongodb-legacy' +import { + cleanupTestDatabase, + db, + waitForDb, +} from '../../../../app/src/infrastructure/mongodb.mjs' +import { UserFeatureUsage } from '../../../../app/src/models/UserFeatureUsage.mjs' + +const { ObjectId } = mongodb + +const MODULE_PATH = + '../../../../app/src/infrastructure/rate-limiters/WorkbenchRateLimiter' + +describe('WorkbenchRateLimiter', function () { + beforeAll(async function () { + await waitForDb() + }) + beforeAll(cleanupTestDatabase) + + beforeEach(async function (ctx) { + ctx.alphaUserId = new ObjectId() + ctx.alphaUser = { + _id: ctx.alphaUserId, + alphaProgram: true, + features: { + aiUsageQuota: 'unlimited', + }, + } + ctx.userWithoutAiAddOnId = new ObjectId() + ctx.userWithAiAddOn = { + _id: 
ctx.userWithoutAiAddOnId, + features: { + aiUsageQuota: 'unlimited', + }, + alphaProgram: false, + } + ctx.otherUserId = new ObjectId() + ctx.otherUser = { + _id: ctx.otherUserId, + features: { + aiUsageQuota: 'basic', + }, + alphaProgram: false, + } + ctx.UserGetter = { + promises: { + getUser: sinon.stub(), + }, + } + ctx.UserGetter.promises.getUser + .withArgs(ctx.alphaUserId) + .resolves(ctx.alphaUser) + ctx.UserGetter.promises.getUser + .withArgs(ctx.userWithoutAiAddOnId) + .resolves(ctx.userWithAiAddOn) + ctx.UserGetter.promises.getUser + .withArgs(ctx.otherUserId) + .resolves(ctx.otherUser) + + ctx.SplitTestHandler = { + promises: { + getAssignmentForUser: sinon.stub(), + featureFlagEnabledForUser: sinon.stub().resolves(true), + }, + } + ctx.SplitTestHandler.promises.getAssignmentForUser + .withArgs(ctx.alphaUserId, 'ai-workbench-release') + .resolves({ variant: 'enabled' }) + + vi.doMock('../../../../app/src/infrastructure/mongodb', () => ({ + ObjectId, + db, + waitForDb, + })) + + vi.doMock('../../../../app/src/Features/User/UserGetter', () => ({ + default: ctx.UserGetter, + })) + + vi.doMock( + '../../../../app/src/Features/SplitTests/SplitTestHandler', + () => ({ + default: ctx.SplitTestHandler, + }) + ) + + vi.doMock( + '../../../../app/src/Features/Analytics/AnalyticsManager', + () => ({ + default: { + recordEventForUser: sinon.stub(), + }, + }) + ) + + ctx.WorkbenchRateLimiter = (await import(MODULE_PATH)).default + }) + + describe('calculateTokenUsage', function () { + it('treats input tokens as 1', function (ctx) { + expect( + ctx.WorkbenchRateLimiter.calculateTokenUsage({ + inputTokenDetails: { + noCacheTokens: 100, + cacheReadTokens: 0, + }, + outputTokens: 0, + }) + ).to.equal(100) + }) + + it('treats output tokens as 10', function (ctx) { + expect( + ctx.WorkbenchRateLimiter.calculateTokenUsage({ + inputTokenDetails: { + noCacheTokens: 0, + cacheReadTokens: 0, + }, + outputTokens: 100, + }) + ).to.equal(1000) + }) + + it('treats output tokens 
correctly', function (ctx) { + expect( + ctx.WorkbenchRateLimiter.calculateTokenUsage({ + inputTokenDetails: { + noCacheTokens: 0, + cacheReadTokens: 0, + }, + outputTokens: 100, + }) + ).to.equal(1000) + }) + + it('rounds up to nearest integer', function (ctx) { + expect( + ctx.WorkbenchRateLimiter.calculateTokenUsage({ + inputTokenDetails: { + noCacheTokens: 1, + cacheReadTokens: 0, + }, + outputTokens: 0, + }) + ).to.equal(1) + }) + + it('sums mixed tokens', function (ctx) { + expect( + ctx.WorkbenchRateLimiter.calculateTokenUsage({ + inputTokenDetails: { + noCacheTokens: 10, + cacheReadTokens: 10, + }, + outputTokens: 10, + }) + ).to.equal(10 + 100 + 0 + 1) + }) + }) + + describe('checkUsage', function () { + describe('with no data', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + ctx.res = { + set: sinon.stub(), + headersSent: false, + } + }) + + it('should not throw', async function (ctx) { + await expect( + ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + ).to.eventually.be.fulfilled + }) + + it('sets rate limit headers', async function (ctx) { + await ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Limit', + '8000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Remaining', + '8000000' + ) + // We can't mock the mongo date, so just check that something was set + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Reset', + matchRateLimit(24 * 60 * 60) + ) + }) + }) + + describe('with existing usage', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + ctx.res = { + set: sinon.stub(), + headersSent: false, + } + const usageRecord = new UserFeatureUsage({ + _id: ctx.alphaUserId, + features: { + aiWorkbench: { + usage: 2000000, + periodStart: new Date(new Date().getTime() - 1 * 60 * 60 * 1000), // 1 hour ago + }, + }, + }) + await usageRecord.save() + 
}) + + it('should not throw if under limit', async function (ctx) { + await expect( + ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + ).to.eventually.be.fulfilled + }) + + it('sets rate limit headers', async function (ctx) { + await ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Limit', + '8000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Remaining', + '6000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Reset', + matchRateLimit(23 * 60 * 60) + ) + }) + + it('throws if over limit', async function (ctx) { + const usageRecord = await UserFeatureUsage.findById( + ctx.alphaUserId + ).exec() + usageRecord.features.aiWorkbench.usage = 9000000 + await usageRecord.save() + + await expect( + ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + ).to.eventually.be.rejectedWith(/rate limit exceeded/i) + }) + }) + + describe('with an expired old usage period', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + ctx.res = { + set: sinon.stub(), + headersSent: false, + } + const usageRecord = new UserFeatureUsage({ + _id: ctx.alphaUserId, + features: { + aiWorkbench: { + usage: 2000000, + periodStart: new Date(new Date().getTime() - 25 * 60 * 60 * 1000), // 25 hours ago + }, + }, + }) + await usageRecord.save() + }) + + it('should not throw', async function (ctx) { + await expect( + ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + ).to.eventually.be.fulfilled + }) + + it('sets rate limit headers', async function (ctx) { + await ctx.WorkbenchRateLimiter.checkUsage(ctx.alphaUserId, ctx.res) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Limit', + '8000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Remaining', + '8000000' + ) + // A new period + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Reset', + matchRateLimit(24 * 60 * 60) + ) 
+ }) + }) + }) + + describe('recordUsage', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + ctx.res = { + set: sinon.stub(), + headersSent: false, + } + }) + + describe('without existing usage', function () { + it('creates new usage record if none exists', async function (ctx) { + await ctx.WorkbenchRateLimiter.recordUsage( + ctx.alphaUserId, + ctx.res, + 1500000 + ) + const usageRecord = await UserFeatureUsage.findById( + ctx.alphaUserId + ).exec() + expect(usageRecord).to.exist + expect(usageRecord.features.aiWorkbench.usage).to.equal(1500000) + expect( + usageRecord.features.aiWorkbench.periodStart.getTime() + ).to.approximately(new Date().getTime(), 60_000) + }) + }) + + describe('with existing usage', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + const usageRecord = new UserFeatureUsage({ + _id: ctx.alphaUserId, + features: { + aiWorkbench: { + usage: 2000000, + periodStart: new Date(new Date().getTime() - 1 * 60 * 60 * 1000), // 1 hour ago + }, + }, + }) + await usageRecord.save() + await ctx.WorkbenchRateLimiter.recordUsage( + ctx.alphaUserId, + ctx.res, + 1000000 + ) + }) + + it('updates existing usage record', async function (ctx) { + const updatedRecord = await UserFeatureUsage.findById( + ctx.alphaUserId + ).exec() + expect(updatedRecord.features.aiWorkbench.usage).to.equal(3000000) + }) + it('sets rate limit headers', async function (ctx) { + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Limit', + '8000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Remaining', + '5000000' + ) + // Keeps the original period start time + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Reset', + matchRateLimit(23 * 60 * 60) + ) + }) + }) + + describe('with an expired old usage period', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.deleteMany({}).exec() + const usageRecord = new UserFeatureUsage({ + 
_id: ctx.alphaUserId, + features: { + aiWorkbench: { + usage: 2000000, + periodStart: new Date(new Date().getTime() - 25 * 60 * 60 * 1000), // 25 hours ago + }, + }, + }) + await usageRecord.save() + await ctx.WorkbenchRateLimiter.recordUsage( + ctx.alphaUserId, + ctx.res, + 1000000 + ) + }) + + it('resets usage and period start', async function (ctx) { + const updatedRecord = await UserFeatureUsage.findById( + ctx.alphaUserId + ).exec() + expect(updatedRecord.features.aiWorkbench.usage).to.equal(1000000) + }) + + it('sets rate limit headers', async function (ctx) { + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Limit', + '8000000' + ) + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Remaining', + '7000000' + ) + // New period start time + expect(ctx.res.set).to.have.been.calledWith( + 'RateLimit-Reset', + matchRateLimit(24 * 60 * 60) + ) + }) + }) + }) +}) + +function matchRateLimit(expectedValue, delta = 60) { + return sinon.match(function (value) { + const number = parseInt(value, 10) + return Math.abs(number - expectedValue) <= delta + }, `${expectedValue} ± ${delta}`) +}