diff --git a/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs index 3bd5ac920b..921ec3000e 100644 --- a/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs +++ b/services/web/app/src/infrastructure/rate-limiters/FeatureUsageRateLimiter.mjs @@ -53,9 +53,10 @@ export default class FeatureUsageRateLimiter { /** * * @param {string} userId + * @param {number} cost - the amount to increment the users usage by, may be 0 for features that are quota locked but dont consume any uses * @param {import('express').Response} res */ - async useFeature(userId, res) { + async useFeature(userId, res, cost = 1) { const allowance = await this._getAllowance(userId) const featureUsages = await UserFeatureUsage.findOneAndUpdate( @@ -72,7 +73,7 @@ export default class FeatureUsageRateLimiter { $lte: [`$features.${this.featureName}.usage`, allowance], }, then: { - $add: [`$features.${this.featureName}.usage`, 1], + $add: [`$features.${this.featureName}.usage`, cost], }, else: `$features.${this.featureName}.usage`, }, @@ -97,7 +98,7 @@ export default class FeatureUsageRateLimiter { * @param {string} userId * @param {import('express').Response} res */ - async decrementFeatureUsage(userId, res) { + async decrementFeatureUsage(userId, res, cost = 1) { const allowance = await this._getAllowance(userId) const featureUsages = await UserFeatureUsage.findOneAndUpdate( { _id: userId }, @@ -106,7 +107,7 @@ export default class FeatureUsageRateLimiter { { $set: { [`features.${this.featureName}.usage`]: { - $add: [`$features.${this.featureName}.usage`, -1], + $add: [`$features.${this.featureName}.usage`, -cost], }, }, }, diff --git a/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs b/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs index 2a8e817706..27c61113c9 100644 --- a/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs +++ b/services/web/test/unit/src/infrastructure/AiFeatureUsageRateLimiter.test.mjs @@ -104,6 +104,20 @@ describe('AiFeatureUsageRateLimiter', function () { describe('useFeature', function () { describe('with some remaining allowance left', function () { it('should suceed', async function (ctx) { + const res = { set: () => null } + await expect( + ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res, 1) + ).to.not.be.rejected + }) + + it('should succeed with cost=0', async function (ctx) { + const res = { set: () => null } + await expect( + ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res, 0) + ).to.not.be.rejected + }) + + it('should succeed with default cost when cost is omitted', async function (ctx) { const res = { set: () => null } await expect(ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res)) .to.not.be.rejected @@ -127,7 +141,7 @@ describe('AiFeatureUsageRateLimiter', function () { it('should be rejected with TooManyRequestsError', async function (ctx) { const res = { set: () => null } await expect( - ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res) + ctx.AiFeatureUsageRateLimiter.useFeature(ctx.userId, res, 1) ).to.be.rejectedWith('aiFeatureUsage rate limit exceeded') }) }) @@ -179,4 +193,30 @@ describe('AiFeatureUsageRateLimiter', function () { ) }) }) + + describe('decrementFeatureUsage', function () { + it('should call findOneAndUpdate to decrement usage', async function (ctx) { + const res = { set: () => null } + await ctx.AiFeatureUsageRateLimiter.decrementFeatureUsage( + ctx.userId, + res, + 1 + ) + expect(ctx.UserFeatureUsageModel.findOneAndUpdate).to.have.been.called + }) + + it('should accept a custom cost parameter', async function (ctx) { + const res = { set: () => null } + await expect( + ctx.AiFeatureUsageRateLimiter.decrementFeatureUsage(ctx.userId, res, 3) + ).to.not.be.rejected + }) + + it('should use default cost of 1 when cost is omitted', async function (ctx) { + const res = { set: () => null } + await expect( + ctx.AiFeatureUsageRateLimiter.decrementFeatureUsage(ctx.userId, res) + ).to.not.be.rejected + }) + }) }) diff --git a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs index bfcea70846..b71069f9a7 100644 --- a/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs +++ b/services/web/test/unit/src/infrastructure/FeatureUsageRateLimiter.sequential.test.mjs @@ -52,8 +52,8 @@ describe('FeatureUsageRateLimiter', function () { describe('with no usage', function (ctx) { it('should succeed', async function (ctx) { const res = { set: () => null } - await expect(ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res)).to - .not.be.rejected + await expect(ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 1)) + .to.not.be.rejected }) }) @@ -69,8 +69,8 @@ describe('FeatureUsageRateLimiter', function () { it('should suceed', async function (ctx) { const res = { set: () => null } - await expect(ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res)).to - .not.be.rejected + await expect(ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 1)) + .to.not.be.rejected }) }) @@ -87,10 +87,60 @@ describe('FeatureUsageRateLimiter', function () { it('should be rejected with TooManyRequestsError', async function (ctx) { const res = { set: () => null } await expect( - ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res) + ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 1) ).to.be.rejectedWith(TooManyRequestsError) }) }) + + describe('with cost=0', function () { + beforeEach(async function (ctx) { + await UserFeatureUsage.create({ + _id: ctx.userId, + features: { + [MOCKED_FEATURE_NAME]: { usage: 50, periodStart: new Date() }, + }, + }) + }) + + it('should not increment usage', async function (ctx) { + const res = { set: () => null } + await ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 0) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(50) + }) + + it('should still be rejected when over the limit', async function (ctx) { + await UserFeatureUsage.findOneAndUpdate( + { _id: ctx.userId }, + { $set: { [`features.${MOCKED_FEATURE_NAME}.usage`]: 101 } } + ) + const res = { set: () => null } + await expect( + ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 0) + ).to.be.rejectedWith(TooManyRequestsError) + }) + }) + + describe('with cost greater than 1', function () { + it('should increment usage by the specified cost', async function (ctx) { + const res = { set: () => null } + await ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res, 5) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(95) + }) + }) + + describe('with default cost parameter', function () { + it('should increment usage by 1 when cost is omitted', async function (ctx) { + const res = { set: () => null } + await ctx.FeatureUsageRateLimiter.useFeature(ctx.userId, res) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(99) + }) + }) }) describe('getRemainingFeatureUses', function () { @@ -136,7 +186,31 @@ describe('FeatureUsageRateLimiter', function () { ctx._getAllowanceStub.resolves(100) }) - it('should return a usage', async function (ctx) { + it('should decrement usage by 1 when cost is 1', async function (ctx) { + const res = { set: () => null } + await ctx.FeatureUsageRateLimiter.decrementFeatureUsage( + ctx.userId, + res, + 1 + ) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(71) + }) + + it('should decrement usage by the specified cost', async function (ctx) { + const res = { set: () => null } + await ctx.FeatureUsageRateLimiter.decrementFeatureUsage( + ctx.userId, + res, + 5 + ) + const usages = + await ctx.FeatureUsageRateLimiter.getRemainingFeatureUses(ctx.userId) + expect(usages[MOCKED_FEATURE_NAME].remainingUsage).to.equal(75) + }) + + it('should decrement usage by 1 when cost is omitted (default)', async function (ctx) { const res = { set: () => null } await ctx.FeatureUsageRateLimiter.decrementFeatureUsage(ctx.userId, res) const usages =