diff --git a/src/shared/utils/pricing.ts b/src/shared/utils/pricing.ts new file mode 100644 index 00000000..6347b122 --- /dev/null +++ b/src/shared/utils/pricing.ts @@ -0,0 +1,111 @@ +import pricingData from '../../../resources/pricing.json'; + +export interface LiteLLMPricing { + input_cost_per_token: number; + output_cost_per_token: number; + cache_creation_input_token_cost?: number; + cache_read_input_token_cost?: number; + input_cost_per_token_above_200k_tokens?: number; + output_cost_per_token_above_200k_tokens?: number; + cache_creation_input_token_cost_above_200k_tokens?: number; + cache_read_input_token_cost_above_200k_tokens?: number; + [key: string]: unknown; +} + +export interface DisplayPricing { + input: number; + output: number; + cache_read: number; + cache_creation: number; +} + +const TIER_THRESHOLD = 200_000; + +const pricingMap = pricingData as Record; + +function tryGetPricing(key: string): LiteLLMPricing | null { + const entry = pricingMap[key]; + if ( + entry && + typeof entry === 'object' && + 'input_cost_per_token' in entry && + 'output_cost_per_token' in entry + ) { + return entry as LiteLLMPricing; + } + return null; +} + +export function getPricing(modelName: string): LiteLLMPricing | null { + const exact = tryGetPricing(modelName); + if (exact) return exact; + + const lowerName = modelName.toLowerCase(); + const lower = tryGetPricing(lowerName); + if (lower) return lower; + + for (const key of Object.keys(pricingMap)) { + if (key.toLowerCase() === lowerName) { + const match = tryGetPricing(key); + if (match) return match; + } + } + + return null; +} + +export function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number { + if (tokens <= 0) return 0; + if (!tieredRate || tokens <= TIER_THRESHOLD) { + return tokens * baseRate; + } + const costBelow = TIER_THRESHOLD * baseRate; + const costAbove = (tokens - TIER_THRESHOLD) * tieredRate; + return costBelow + costAbove; +} + +export function calculateMessageCost( + modelName: string, + inputTokens: number, + outputTokens: number, + cacheReadTokens: number, + cacheCreationTokens: number +): number { + const pricing = getPricing(modelName); + if (!pricing) return 0; + + const inputCost = calculateTieredCost( + inputTokens, + pricing.input_cost_per_token, + pricing.input_cost_per_token_above_200k_tokens + ); + const outputCost = calculateTieredCost( + outputTokens, + pricing.output_cost_per_token, + pricing.output_cost_per_token_above_200k_tokens + ); + const cacheCreationCost = calculateTieredCost( + cacheCreationTokens, + pricing.cache_creation_input_token_cost ?? 0, + pricing.cache_creation_input_token_cost_above_200k_tokens + ); + const cacheReadCost = calculateTieredCost( + cacheReadTokens, + pricing.cache_read_input_token_cost ?? 0, + pricing.cache_read_input_token_cost_above_200k_tokens + ); + + return inputCost + outputCost + cacheCreationCost + cacheReadCost; +} + +export function getDisplayPricing(modelName: string): DisplayPricing | null { + const pricing = getPricing(modelName); + if (!pricing) return null; + + return { + input: pricing.input_cost_per_token * 1_000_000, + output: pricing.output_cost_per_token * 1_000_000, + cache_read: (pricing.cache_read_input_token_cost ?? 0) * 1_000_000, + cache_creation: (pricing.cache_creation_input_token_cost ?? 0) * 1_000_000, + }; +} diff --git a/test/shared/utils/pricing.test.ts b/test/shared/utils/pricing.test.ts new file mode 100644 index 00000000..dff7e6e1 --- /dev/null +++ b/test/shared/utils/pricing.test.ts @@ -0,0 +1,80 @@ +import { describe, it, expect } from 'vitest'; +import { + getPricing, + calculateTieredCost, + calculateMessageCost, + getDisplayPricing, +} from '@shared/utils/pricing'; + +describe('Shared Pricing Module', () => { + describe('getPricing', () => { + it('should find pricing by exact model name', () => { + const pricing = getPricing('claude-3-5-sonnet-20241022'); + expect(pricing).not.toBeNull(); + expect(pricing!.input_cost_per_token).toBeGreaterThan(0); + expect(pricing!.output_cost_per_token).toBeGreaterThan(0); + }); + + it('should find pricing case-insensitively', () => { + const pricing = getPricing('Claude-3-5-Sonnet-20241022'); + expect(pricing).not.toBeNull(); + }); + + it('should return null for unknown models', () => { + const pricing = getPricing('totally-fake-model-xyz'); + expect(pricing).toBeNull(); + }); + }); + + describe('calculateTieredCost', () => { + it('should use base rate for tokens below 200k', () => { + const cost = calculateTieredCost(100_000, 0.000003); + expect(cost).toBeCloseTo(0.3, 6); + }); + + it('should apply tiered rate above 200k', () => { + const cost = calculateTieredCost(250_000, 0.000003, 0.000006); + expect(cost).toBeCloseTo(0.9, 6); + }); + + it('should use base rate when no tiered rate provided', () => { + const cost = calculateTieredCost(250_000, 0.000015); + expect(cost).toBeCloseTo(3.75, 6); + }); + + it('should return 0 for zero or negative tokens', () => { + expect(calculateTieredCost(0, 0.000003)).toBe(0); + expect(calculateTieredCost(-100, 0.000003)).toBe(0); + }); + }); + + describe('calculateMessageCost', () => { + it('should compute cost for a known model', () => { + const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 0, 0); + expect(cost).toBeCloseTo(0.0105, 6); + }); + + it('should return 0 for unknown models', () => { + const cost = calculateMessageCost('unknown-model', 1000, 500, 0, 0); + expect(cost).toBe(0); + }); + + it('should include cache token costs', () => { + const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 300, 200); + expect(cost).toBeGreaterThan(0.0105); + }); + }); + + describe('getDisplayPricing', () => { + it('should return per-million rates for a known model', () => { + const dp = getDisplayPricing('claude-3-5-sonnet-20241022'); + expect(dp).not.toBeNull(); + expect(dp!.input).toBeCloseTo(3.0, 1); + expect(dp!.output).toBeCloseTo(15.0, 1); + }); + + it('should return null for unknown models', () => { + expect(getDisplayPricing('unknown-model')).toBeNull(); + }); + }); +});