diff --git a/src/main/utils/jsonl.ts b/src/main/utils/jsonl.ts index 114090d0..60297c7f 100644 --- a/src/main/utils/jsonl.ts +++ b/src/main/utils/jsonl.ts @@ -9,6 +9,7 @@ import { isCommandOutputContent, sanitizeDisplayContent } from '@shared/utils/contentSanitizer'; import { createLogger } from '@shared/utils/logger'; +import { calculateMessageCost } from '@shared/utils/pricing'; import * as readline from 'readline'; import { LocalFileSystemProvider } from '../services/infrastructure/LocalFileSystemProvider'; @@ -212,113 +213,6 @@ function parseMessageType(type?: string): MessageType | null { } } -// ============================================================================= -// Cost Calculation -// ============================================================================= - -import * as fs from 'fs'; -import * as path from 'path'; - -interface ModelPricing { - input_cost_per_token: number; - output_cost_per_token: number; - cache_creation_input_token_cost?: number; - cache_read_input_token_cost?: number; - input_cost_per_token_above_200k_tokens?: number; - output_cost_per_token_above_200k_tokens?: number; - cache_creation_input_token_cost_above_200k_tokens?: number; - cache_read_input_token_cost_above_200k_tokens?: number; - [key: string]: unknown; -} - -const TIER_THRESHOLD = 200_000; - -// Cache pricing data in memory (loaded once on first use) -let pricingCache: Record | null = null; - -/** - * Load pricing data from resources directory. - * Uses electron-vite resource directory pattern: - * - Development: resources/pricing.json (project root) - * - Production: process.resourcesPath/pricing.json - */ -function loadPricingData(): Record { - if (pricingCache !== null) { - return pricingCache; - } - - try { - // Determine if we're in development or production - const isDev = process.env.NODE_ENV === 'development' || !process.resourcesPath; - - let pricingPath: string; - if (isDev) { - // Development: Compiled code is in dist-electron/main/ - // __dirname = /path/to/project/dist-electron/main - // Need to go up 2 levels to reach project root, then into resources/ - pricingPath = path.join(__dirname, '..', '..', 'resources', 'pricing.json'); - } else { - // Production: pricing.json in app's resources directory - pricingPath = path.join(process.resourcesPath, 'pricing.json'); - } - - const data = fs.readFileSync(pricingPath, 'utf-8'); - pricingCache = JSON.parse(data) as Record; - return pricingCache; - } catch (error) { - console.error('Failed to load pricing data:', error); - // Return empty object if pricing data can't be loaded - pricingCache = {}; - return pricingCache; - } -} - -function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number { - if (tokens <= 0) return 0; - if (!tieredRate || tokens <= TIER_THRESHOLD) { - return tokens * baseRate; - } - const costBelow = TIER_THRESHOLD * baseRate; - const costAbove = (tokens - TIER_THRESHOLD) * tieredRate; - return costBelow + costAbove; -} - -function getPricing(modelName: string): ModelPricing | null { - const pricing = loadPricingData(); - - const tryGet = (key: string): ModelPricing | null => { - const entry = pricing[key]; - if ( - entry && - typeof entry === 'object' && - 'input_cost_per_token' in entry && - 'output_cost_per_token' in entry - ) { - return entry as ModelPricing; - } - return null; - }; - - // Try exact match - const exact = tryGet(modelName); - if (exact) return exact; - - // Try lowercase - const lowerName = modelName.toLowerCase(); - const lower = tryGet(lowerName); - if (lower) return lower; - - // Try case-insensitive search - for (const key of Object.keys(pricing)) { - if (key.toLowerCase() === lowerName) { - const match = tryGet(key); - if (match) return match; - } - } - - return null; -} - // ============================================================================= // Metrics Calculation // ============================================================================= @@ -335,7 +229,6 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics { let outputTokens = 0; let cacheReadTokens = 0; let cacheCreationTokens = 0; - let modelName: string | undefined; // Get timestamps for duration (loop instead of Math.min/max spread to avoid stack overflow on large sessions) const timestamps = messages.map((m) => m.timestamp.getTime()).filter((t) => !isNaN(t)); @@ -366,36 +259,14 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics { cacheReadTokens += msgCacheReadTokens; cacheCreationTokens += msgCacheCreationTokens; - // Calculate cost for this message if we have pricing data - if (msg.model && !modelName) { - modelName = msg.model; - } - if (msg.model) { - const pricing = getPricing(msg.model); - if (pricing) { - const inputCost = calculateTieredCost( - msgInputTokens, - pricing.input_cost_per_token, - pricing.input_cost_per_token_above_200k_tokens - ); - const outputCost = calculateTieredCost( - msgOutputTokens, - pricing.output_cost_per_token, - pricing.output_cost_per_token_above_200k_tokens - ); - const cacheCreationCost = calculateTieredCost( - msgCacheCreationTokens, - pricing.cache_creation_input_token_cost ?? 0, - pricing.cache_creation_input_token_cost_above_200k_tokens - ); - const cacheReadCost = calculateTieredCost( - msgCacheReadTokens, - pricing.cache_read_input_token_cost ?? 0, - pricing.cache_read_input_token_cost_above_200k_tokens - ); - costUsd += inputCost + outputCost + cacheCreationCost + cacheReadCost; - } + costUsd += calculateMessageCost( + msg.model, + msgInputTokens, + msgOutputTokens, + msgCacheReadTokens, + msgCacheCreationTokens + ); } } } diff --git a/test/main/utils/costCalculation.test.ts b/test/main/utils/costCalculation.test.ts index 42bd4d79..b46f96b9 100644 --- a/test/main/utils/costCalculation.test.ts +++ b/test/main/utils/costCalculation.test.ts @@ -2,44 +2,11 @@ * Tests for cost calculation in jsonl.ts */ -import { describe, it, expect, beforeEach, vi } from 'vitest'; -import * as fs from 'fs'; +import { describe, it, expect } from 'vitest'; import { calculateMetrics } from '@main/utils/jsonl'; import type { ParsedMessage } from '@main/types'; -// Mock fs module -vi.mock('fs'); - describe('Cost Calculation', () => { - // Sample pricing data matching Claude models - const mockPricingData = { - 'claude-3-5-sonnet-20241022': { - input_cost_per_token: 0.000003, - output_cost_per_token: 0.000015, - cache_creation_input_token_cost: 0.00000375, - cache_read_input_token_cost: 0.0000003, - input_cost_per_token_above_200k_tokens: 0.000006, - output_cost_per_token_above_200k_tokens: 0.00003, - cache_creation_input_token_cost_above_200k_tokens: 0.0000075, - cache_read_input_token_cost_above_200k_tokens: 0.0000006, - }, - 'claude-3-opus-20240229': { - input_cost_per_token: 0.000015, - output_cost_per_token: 0.000075, - cache_creation_input_token_cost: 0.00001875, - cache_read_input_token_cost: 0.0000015, - }, - }; - - beforeEach(() => { - // Reset modules to clear pricing cache - vi.resetModules(); - - // Mock fs.readFileSync to return our test pricing data - vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify(mockPricingData)); - vi.mocked(fs.existsSync).mockReturnValue(true); - }); - describe('Basic Cost Calculation', () => { it('should calculate cost for simple token usage', () => { const messages: ParsedMessage[] = [ @@ -166,7 +133,7 @@ describe('Cost Calculation', () => { expect(metrics.costUsd).toBeCloseTo(1.05, 6); }); - it('should use tiered rates for input tokens above 200k threshold', () => { + it('should use base rates for input tokens above 200k when model has no tiered pricing', () => { const messages: ParsedMessage[] = [ { type: 'assistant', @@ -186,13 +153,14 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); - // Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9 + // claude-3-5-sonnet-20241022 has no tiered rates in pricing.json, so base rates apply + // Input: 250000 * 0.000003 = 0.75 // Output: 1000 * 0.000015 = 0.015 - // Total: 0.915 - expect(metrics.costUsd).toBeCloseTo(0.915, 6); + // Total: 0.765 + expect(metrics.costUsd).toBeCloseTo(0.765, 6); }); - it('should use tiered rates for output tokens above 200k threshold', () => { + it('should use base rates for output tokens above 200k when model has no tiered pricing', () => { const messages: ParsedMessage[] = [ { type: 'assistant', @@ -212,13 +180,14 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); + // No tiered rates, so base rates for all tokens // Input: 1000 * 0.000003 = 0.003 - // Output: (200000 * 0.000015) + (50000 * 0.00003) = 3.0 + 1.5 = 4.5 - // Total: 4.503 - expect(metrics.costUsd).toBeCloseTo(4.503, 6); + // Output: 250000 * 0.000015 = 3.75 + // Total: 3.753 + expect(metrics.costUsd).toBeCloseTo(3.753, 6); }); - it('should use tiered rates for cache tokens above 200k threshold', () => { + it('should use base rates for cache tokens above 200k when model has no tiered pricing', () => { const messages: ParsedMessage[] = [ { type: 'assistant', @@ -240,12 +209,13 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); + // No tiered rates for this model, so base rates apply // Input: 1000 * 0.000003 = 0.003 // Output: 1000 * 0.000015 = 0.015 - // Cache creation: (200000 * 0.00000375) + (50000 * 0.0000075) = 0.75 + 0.375 = 1.125 - // Cache read: (200000 * 0.0000003) + (50000 * 0.0000006) = 0.06 + 0.03 = 0.09 - // Total: 1.233 - expect(metrics.costUsd).toBeCloseTo(1.233, 6); + // Cache creation: 250000 * 0.00000375 = 0.9375 + // Cache read: 250000 * 0.0000003 = 0.075 + // Total: 1.0305 + expect(metrics.costUsd).toBeCloseTo(1.0305, 6); }); it('should handle model without tiered pricing', () => { @@ -274,6 +244,34 @@ describe('Cost Calculation', () => { // Total: 22.5 expect(metrics.costUsd).toBeCloseTo(22.5, 6); }); + + it('should use tiered rates for a model that has them (claude-4-sonnet)', () => { + const messages: ParsedMessage[] = [ + { + type: 'assistant', + uuid: 'msg-1', + timestamp: new Date(), + content: [], + model: 'claude-4-sonnet-20250514', + usage: { + input_tokens: 250_000, + output_tokens: 1_000, + }, + toolCalls: [], + toolResults: [], + isSidechain: false, + }, + ]; + + const metrics = calculateMetrics(messages); + + // claude-4-sonnet has tiered rates: + // input base=0.000003, above_200k=0.000006 + // Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9 + // Output: 1000 * 0.000015 = 0.015 + // Total: 0.915 + expect(metrics.costUsd).toBeCloseTo(0.915, 6); + }); }); describe('Multiple Messages', () => { @@ -405,48 +403,6 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); expect(metrics.costUsd).toBe(0); }); - - it('should handle pricing data load failure gracefully', async () => { - // Suppress expected console.error for this test - const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); - - // Reset modules to clear the pricing cache - vi.resetModules(); - - // Mock fs to throw error BEFORE importing calculateMetrics - vi.mocked(fs.readFileSync).mockImplementation(() => { - throw new Error('File not found'); - }); - - // Re-import calculateMetrics to get fresh instance with cleared cache - const { calculateMetrics: freshCalculateMetrics } = await import('@main/utils/jsonl'); - - const messages: ParsedMessage[] = [ - { - type: 'assistant', - uuid: 'msg-1', - timestamp: new Date(), - content: [], - model: 'claude-3-5-sonnet-20241022', - usage: { - input_tokens: 1000, - output_tokens: 500, - }, - toolCalls: [], - toolResults: [], - isSidechain: false, - }, - ]; - - const metrics = freshCalculateMetrics(messages); - expect(metrics.costUsd).toBe(0); - - // Verify that console.error was called (error was logged) - expect(consoleErrorSpy).toHaveBeenCalled(); - - // Restore console.error - consoleErrorSpy.mockRestore(); - }); }); describe('Model Name Lookup', () => { @@ -535,7 +491,7 @@ describe('Cost Calculation', () => { expect(metrics.costUsd).not.toBeCloseTo(incorrectAggregatedCost, 2); }); - it('should apply tiered rates when individual messages exceed 200k', () => { + it('should use base rates when individual messages exceed 200k and model has no tiered rates', () => { const messages: ParsedMessage[] = [ { type: 'assistant', @@ -556,11 +512,9 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); - // Single message with 300k cache_read tokens - // First 200k: 200,000 * 0.0000003 = $0.06 - // Remaining 100k: 100,000 * 0.0000006 = $0.06 - // Total: $0.12 - const expectedCost = 200000 * 0.0000003 + 100000 * 0.0000006; + // No tiered rates for this model, so all 300k at base rate + // 300,000 * 0.0000003 = $0.09 + const expectedCost = 300000 * 0.0000003; expect(metrics.costUsd).toBeCloseTo(expectedCost, 6); }); });