diff --git a/services/web/app/src/Features/Subscription/SubscriptionHandler.mjs b/services/web/app/src/Features/Subscription/SubscriptionHandler.mjs index 930838d662..f04607859a 100644 --- a/services/web/app/src/Features/Subscription/SubscriptionHandler.mjs +++ b/services/web/app/src/Features/Subscription/SubscriptionHandler.mjs @@ -13,6 +13,8 @@ import UserUpdater from '../User/UserUpdater.mjs' import Modules from '../../infrastructure/Modules.mjs' import { AI_ADD_ON_CODE } from './AiHelper.mjs' import CustomerIoPlanHelpers from './CustomerIoPlanHelpers.mjs' +import WorkbenchRateLimiter from '../../infrastructure/rate-limiters/WorkbenchRateLimiter.mjs' +import AiFeatureUsageRateLimiter from '../../infrastructure/rate-limiters/AiFeatureUsageRateLimiter.mjs' /** * @import { PaymentProviderSubscriptionChange } from './PaymentProviderEntities.mjs' @@ -130,6 +132,13 @@ async function updateSubscription(user, planCode) { user._id ) + try { + await WorkbenchRateLimiter.resetTokenUsage(user._id) + await AiFeatureUsageRateLimiter.resetFeatureUsage(user._id) + } catch (err) { + logger.error({ err, userId: user._id }, 'failed to reset AI usage limits') + } + if (previousPlanType) { Modules.promises.hooks .fire('setUserProperties', user._id, { diff --git a/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs index 78326b7dd3..04f476aaff 100644 --- a/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs +++ b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs @@ -20,7 +20,7 @@ export default class FeatureUsageRateLimiter { this.featureName = featureName } - _resetFeatureUsagePipelineSection() { + resetFeatureUsagePipelineSection() { return { $set: { features: { @@ -62,7 +62,7 @@ export default class FeatureUsageRateLimiter { const featureUsages = await UserFeatureUsage.findOneAndUpdate( { _id: userId }, [ - this._resetFeatureUsagePipelineSection(), + this.resetFeatureUsagePipelineSection(), { $set: { features: { @@ -108,7 +108,7 @@ export default class FeatureUsageRateLimiter { const featureUsages = await UserFeatureUsage.findOneAndUpdate( { _id: userId }, [ - this._resetFeatureUsagePipelineSection(), + this.resetFeatureUsagePipelineSection(), { $set: { [`features.${this.featureName}.usage`]: { @@ -130,6 +130,24 @@ export default class FeatureUsageRateLimiter { setRateLimitHeaders(res, featureUsage, allowance) } + /** + * @param {string} userId + */ + async resetFeatureUsage(userId) { + await UserFeatureUsage.findOneAndUpdate( + { _id: userId }, + { + $set: { + [`features.${this.featureName}`]: { + usage: 0, + periodStart: new Date(), + }, + }, + }, + { upsert: true } + ).exec() + } + /** * @param {string} userId * @returns {Promise<{[featureName: string]: { remainingUsage: number, resetDate?: string}}>} diff --git a/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs index 79afdf75f7..39e1b23d69 100644 --- a/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs +++ b/services/web/app/src/infrastructure/rate-limiters/TokenUsageRateLimiter.mjs @@ -18,7 +18,7 @@ export default class TokenUsageRateLimiter { this.featureName = featureName } - _resetFeatureUsagePipelineSection() { + resetTokenUsagePipelineSection() { return { $set: { features: { @@ -69,7 +69,7 @@ export default class TokenUsageRateLimiter { const featureUsages = await UserFeatureUsage.findOneAndUpdate( { _id: userId }, [ - this._resetFeatureUsagePipelineSection(), + this.resetTokenUsagePipelineSection(), { $set: { features: { @@ -92,6 +92,24 @@ export default class TokenUsageRateLimiter { this.setRateLimitHeaders(res, featureUsage, allowance) } + /** + * @param {string} userId + */ + async resetTokenUsage(userId) { + await UserFeatureUsage.findOneAndUpdate( + { _id: userId }, + { + $set: { + [`features.${this.featureName}`]: { + usage: 0, + periodStart: new Date(), + }, + }, + }, + { upsert: true } + ).exec() + } + /** * * @param {string} userId diff --git a/services/web/test/unit/src/Subscription/SubscriptionHandler.test.mjs b/services/web/test/unit/src/Subscription/SubscriptionHandler.test.mjs index 667efd483c..0b41882900 100644 --- a/services/web/test/unit/src/Subscription/SubscriptionHandler.test.mjs +++ b/services/web/test/unit/src/Subscription/SubscriptionHandler.test.mjs @@ -145,6 +145,14 @@ describe('SubscriptionHandler', function () { }, } + ctx.WorkbenchRateLimiter = { + resetTokenUsage: sinon.stub().resolves(), + } + + ctx.AiFeatureUsageRateLimiter = { + resetFeatureUsage: sinon.stub().resolves(), + } + vi.doMock( '../../../../app/src/Features/Subscription/RecurlyWrapper', () => ({ @@ -227,6 +235,20 @@ describe('SubscriptionHandler', function () { }), })) + vi.doMock( + '../../../../app/src/infrastructure/rate-limiters/WorkbenchRateLimiter', + () => ({ + default: ctx.WorkbenchRateLimiter, + }) + ) + + vi.doMock( + '../../../../app/src/infrastructure/rate-limiters/AiFeatureUsageRateLimiter', + () => ({ + default: ctx.AiFeatureUsageRateLimiter, + }) + ) + ctx.SubscriptionHandler = (await import(MODULE_PATH)).default }) @@ -379,6 +401,37 @@ describe('SubscriptionHandler', function () { ctx.user._id ) }) + + it('should reset the ai rate limiter usages on a successful update', async function (ctx) { + ctx.LimitationsManager.promises.userHasSubscription.resolves({ + hasSubscription: true, + subscription: ctx.subscription, + }) + await ctx.SubscriptionHandler.promises.updateSubscription( + ctx.user, + ctx.plan_code + ) + expect(ctx.WorkbenchRateLimiter.resetTokenUsage).to.have.been.calledWith( + ctx.user._id + ) + expect( + ctx.AiFeatureUsageRateLimiter.resetFeatureUsage + ).to.have.been.calledWith(ctx.user._id) + }) + + it('should not reset the ai rate limiter usages when no subscription exists', async function (ctx) { + ctx.LimitationsManager.promises.userHasSubscription.resolves({ + hasSubscription: false, + subscription: null, + }) + await ctx.SubscriptionHandler.promises.updateSubscription( + ctx.user, + ctx.plan_code + ) + expect(ctx.WorkbenchRateLimiter.resetTokenUsage).to.not.have.been.called + expect(ctx.AiFeatureUsageRateLimiter.resetFeatureUsage).to.not.have.been + .called + }) }) describe('cancelPendingSubscriptionChange', function () { diff --git a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs index b71069f9a7..05f6a76f5c 100644 --- a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs +++ b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs @@ -174,6 +174,48 @@ describe('FeatureUsageRateLimiter', function () { }) }) + describe('resetFeatureUsage', function () { + beforeEach(function (ctx) { + ctx._getAllowanceStub.resolves(100) + }) + + describe('with some usage', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.create({ + _id: ctx.userId, + features: { + [MOCKED_FEATURE_NAME]: { usage: 75, periodStart: new Date(0) }, + }, + }) + }) + + it('should reset usage back to the full allowance', async function (ctx) { + await ctx.FeatureUsageRateLimiter.resetFeatureUsage(ctx.userId) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(100) + }) + + it('should set periodStart to roughly the current time', async function (ctx) { + const before = Date.now() + await ctx.FeatureUsageRateLimiter.resetFeatureUsage(ctx.userId) + const doc = await UserFeatureUsage.findOne({ _id: ctx.userId }).exec() + const periodStart = doc.features[MOCKED_FEATURE_NAME].periodStart + expect(periodStart.getTime()).to.be.at.least(before) + expect(periodStart.getTime()).to.be.at.most(Date.now()) + }) + }) + + describe('when no usage record exists', function () { + it('should upsert a fresh usage record with zero usage', async function (ctx) { + await ctx.FeatureUsageRateLimiter.resetFeatureUsage(ctx.userId) + const doc = await UserFeatureUsage.findOne({ _id: ctx.userId }).exec() + expect(doc).to.not.be.null + expect(doc.features[MOCKED_FEATURE_NAME].usage).to.equal(0) + }) + }) + }) + describe('decrementFeatureUsage', function () { describe('with some usage', function () { beforeEach(async function (ctx) { diff --git a/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs b/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs index 01317867f7..131442ecb7 100644 --- a/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs +++ b/services/web/test/unit/src/infrastructure/WorkbenchRateLimiter.sequential.test.mjs @@ -292,6 +292,42 @@ describe('WorkbenchRateLimiter', function () { }) }) + describe('resetTokenUsage', function () { + beforeEach(async function () { + await UserFeatureUsage.deleteMany({}).exec() + }) + + it('resets usage to 0 and refreshes periodStart when existing usage is present', async function (ctx) { + const usageRecord = new UserFeatureUsage({ + _id: ctx.alphaUserId, + features: { + aiWorkbench: { + usage: 5000000, + periodStart: new Date(new Date().getTime() - 1 * 60 * 60 * 1000), + }, + }, + }) + await usageRecord.save() + + const before = Date.now() + await ctx.WorkbenchRateLimiter.resetTokenUsage(ctx.alphaUserId) + + const updated = await UserFeatureUsage.findById(ctx.alphaUserId).exec() + expect(updated.features.aiWorkbench.usage).to.equal(0) + expect(updated.features.aiWorkbench.periodStart.getTime()).to.be.at.least( + before + ) + }) + + it('upserts a fresh usage record with zero usage when none exists', async function (ctx) { + await ctx.WorkbenchRateLimiter.resetTokenUsage(ctx.alphaUserId) + + const created = await UserFeatureUsage.findById(ctx.alphaUserId).exec() + expect(created).to.exist + expect(created.features.aiWorkbench.usage).to.equal(0) + }) + }) + describe('recordUsage', function () { beforeEach(async function (ctx) { await UserFeatureUsage.deleteMany({}).exec()