From 723760938ea11db21eb961bb58de1505667bbfc0 Mon Sep 17 00:00:00 2001 From: KaustubhPatange Date: Sun, 22 Feb 2026 15:18:52 +0530 Subject: [PATCH] feat: add session cost in session header --- .gitignore | 2 +- src/main/utils/jsonl.ts | 77 +++--- src/renderer/components/chat/ChatHistory.tsx | 1 + .../components/SessionContextHeader.tsx | 22 ++ .../chat/SessionContextPanel/index.tsx | 2 + .../chat/SessionContextPanel/types.ts | 8 +- test/main/utils/costCalculation.test.ts | 81 ++++++- test/shared/utils/costFormatting.test.ts | 228 ++++++++++++++++++ 8 files changed, 379 insertions(+), 42 deletions(-) create mode 100644 test/shared/utils/costFormatting.test.ts diff --git a/.gitignore b/.gitignore index 6afcf5ce..28223886 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,4 @@ temp/ eslint-fix/ -remotion/* +remotion/* \ No newline at end of file diff --git a/src/main/utils/jsonl.ts b/src/main/utils/jsonl.ts index 323bcf8d..13034ea1 100644 --- a/src/main/utils/jsonl.ts +++ b/src/main/utils/jsonl.ts @@ -351,47 +351,58 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics { } } + // Calculate cost per-message, then sum (tiered pricing applies per-API-call, not to aggregated totals) + let costUsd = 0; + for (const msg of messages) { if (msg.usage) { - inputTokens += msg.usage.input_tokens ?? 0; - outputTokens += msg.usage.output_tokens ?? 0; - cacheReadTokens += msg.usage.cache_read_input_tokens ?? 0; - cacheCreationTokens += msg.usage.cache_creation_input_tokens ?? 0; + const msgInputTokens = msg.usage.input_tokens ?? 0; + const msgOutputTokens = msg.usage.output_tokens ?? 0; + const msgCacheReadTokens = msg.usage.cache_read_input_tokens ?? 0; + const msgCacheCreationTokens = msg.usage.cache_creation_input_tokens ?? 0; + + inputTokens += msgInputTokens; + outputTokens += msgOutputTokens; + cacheReadTokens += msgCacheReadTokens; + cacheCreationTokens += msgCacheCreationTokens; + + // Calculate cost for this message if we have pricing data + if (msg.model && !modelName) { + modelName = msg.model; + } + + if (msg.model) { + const pricing = getPricing(msg.model); + if (pricing) { + const inputCost = calculateTieredCost( + msgInputTokens, + pricing.input_cost_per_token, + pricing.input_cost_per_token_above_200k_tokens + ); + const outputCost = calculateTieredCost( + msgOutputTokens, + pricing.output_cost_per_token, + pricing.output_cost_per_token_above_200k_tokens + ); + const cacheCreationCost = calculateTieredCost( + msgCacheCreationTokens, + pricing.cache_creation_input_token_cost ?? 0, + pricing.cache_creation_input_token_cost_above_200k_tokens + ); + const cacheReadCost = calculateTieredCost( + msgCacheReadTokens, + pricing.cache_read_input_token_cost ?? 0, + pricing.cache_read_input_token_cost_above_200k_tokens + ); + costUsd += inputCost + outputCost + cacheCreationCost + cacheReadCost; + } + } } if (!modelName && msg.model) { modelName = msg.model; } } - // Calculate cost - let costUsd = 0; - if (modelName) { - const pricing = getPricing(modelName); - if (pricing) { - const inputCost = calculateTieredCost( - inputTokens, - pricing.input_cost_per_token, - pricing.input_cost_per_token_above_200k_tokens - ); - const outputCost = calculateTieredCost( - outputTokens, - pricing.output_cost_per_token, - pricing.output_cost_per_token_above_200k_tokens - ); - const cacheCreationCost = calculateTieredCost( - cacheCreationTokens, - pricing.cache_creation_input_token_cost ?? 0, - pricing.cache_creation_input_token_cost_above_200k_tokens - ); - const cacheReadCost = calculateTieredCost( - cacheReadTokens, - pricing.cache_read_input_token_cost ?? 0, - pricing.cache_read_input_token_cost_above_200k_tokens - ); - costUsd = inputCost + outputCost + cacheCreationCost + cacheReadCost; - } - } - return { durationMs: maxTime - minTime, totalTokens: inputTokens + cacheCreationTokens + cacheReadTokens + outputTokens, diff --git a/src/renderer/components/chat/ChatHistory.tsx b/src/renderer/components/chat/ChatHistory.tsx index a6a4acbe..24644b50 100644 --- a/src/renderer/components/chat/ChatHistory.tsx +++ b/src/renderer/components/chat/ChatHistory.tsx @@ -824,6 +824,7 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => { onNavigateToTool={handleNavigateToTool} onNavigateToUserGroup={handleNavigateToUserGroup} totalSessionTokens={lastAiGroupTotalTokens} + sessionMetrics={sessionDetail?.metrics} phaseInfo={sessionPhaseInfo ?? undefined} selectedPhase={selectedContextPhase} onPhaseChange={setSelectedContextPhase} diff --git a/src/renderer/components/chat/SessionContextPanel/components/SessionContextHeader.tsx b/src/renderer/components/chat/SessionContextPanel/components/SessionContextHeader.tsx index ae9da2e1..d0791e00 100644 --- a/src/renderer/components/chat/SessionContextPanel/components/SessionContextHeader.tsx +++ b/src/renderer/components/chat/SessionContextPanel/components/SessionContextHeader.tsx @@ -12,6 +12,7 @@ import { COLOR_TEXT_MUTED, COLOR_TEXT_SECONDARY, } from '@renderer/constants/cssVariables'; +import { formatCostUsd } from '@shared/utils/costFormatting'; import { ArrowDownWideNarrow, FileText, LayoutList, X } from 'lucide-react'; import { formatTokens } from '../utils/formatting'; @@ -19,12 +20,14 @@ import { formatTokens } from '../utils/formatting'; import { SessionContextHelpTooltip } from './SessionContextHelpTooltip'; import type { ContextViewMode } from '../types'; +import type { SessionMetrics } from '@main/types'; import type { ContextPhaseInfo } from '@renderer/types/contextInjection'; interface SessionContextHeaderProps { injectionCount: number; totalTokens: number; totalSessionTokens?: number; + sessionMetrics?: SessionMetrics; onClose?: () => void; phaseInfo?: ContextPhaseInfo; selectedPhase: number | null; @@ -37,6 +40,7 @@ export const SessionContextHeader = ({ injectionCount, totalTokens, totalSessionTokens, + sessionMetrics, onClose, phaseInfo, selectedPhase, @@ -115,6 +119,24 @@ export const SessionContextHeader = ({ )} + {/* Session Metrics Breakdown */} + {sessionMetrics && ( +
+ {/* Cost */} + {sessionMetrics.costUsd !== undefined && sessionMetrics.costUsd > 0 && ( +
+ Session Cost: + + {formatCostUsd(sessionMetrics.costUsd)} + +
+ )} +
+ )} + {/* Phase selector - only shown when compactions exist */} {phaseInfo && phaseInfo.phases.length > 1 && (
void; /** Total session tokens (input + output + cache) for comparison */ totalSessionTokens?: number; + /** Full session metrics (input, output, cache tokens, cost) */ + sessionMetrics?: SessionMetrics; /** Phase information for phase selector */ phaseInfo?: ContextPhaseInfo; /** Currently selected phase (null = current/latest) */ diff --git a/test/main/utils/costCalculation.test.ts b/test/main/utils/costCalculation.test.ts index 04a56044..42bd4d79 100644 --- a/test/main/utils/costCalculation.test.ts +++ b/test/main/utils/costCalculation.test.ts @@ -317,7 +317,7 @@ describe('Cost Calculation', () => { expect(metrics.costUsd).toBeCloseTo(0.0315, 6); }); - it('should use first model found when calculating aggregated cost', () => { + it("should calculate cost per-message using each message's model", () => { const messages: ParsedMessage[] = [ { type: 'assistant', @@ -351,10 +351,11 @@ describe('Cost Calculation', () => { const metrics = calculateMetrics(messages); - // Uses first model (sonnet) pricing for all tokens - // Total tokens: 2000 input, 1000 output - // Cost: (2000 * 0.000003) + (1000 * 0.000015) = 0.006 + 0.015 = 0.021 - expect(metrics.costUsd).toBeCloseTo(0.021, 6); + // Each message uses its own model's pricing + // Message 1 (sonnet): (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105 + // Message 2 (opus): (1000 * 0.000015) + (500 * 0.000075) = 0.015 + 0.0375 = 0.0525 + // Total cost: 0.0105 + 0.0525 = 0.063 + expect(metrics.costUsd).toBeCloseTo(0.063, 6); }); }); @@ -494,6 +495,76 @@ describe('Cost Calculation', () => { }); }); + describe('Per-Message Tiering', () => { + it('should apply tiered pricing per-message, not to aggregated totals', () => { + // Scenario: Many messages each with cache_read tokens < 200k, + // but aggregated total > 200k + // Each message should use base rates, not tiered rates + const messages: ParsedMessage[] = []; + + // Create 10 messages, each with 50k cache_read tokens (500k total) + for (let i = 0; i < 10; i++) { + messages.push({ + type: 'assistant', + uuid: `msg-${i}`, + timestamp: new Date(), + content: [], + model: 'claude-3-5-sonnet-20241022', + usage: { + input_tokens: 0, + output_tokens: 0, + cache_read_input_tokens: 50000, + }, + toolCalls: [], + toolResults: [], + isSidechain: false, + }); + } + + const metrics = calculateMetrics(messages); + + // Per-message tiering: Each message uses base rate (< 200k threshold) + // Each message: 50,000 * 0.0000003 = $0.015 + // Total: 10 * $0.015 = $0.15 + const expectedCost = 10 * 50000 * 0.0000003; + expect(metrics.costUsd).toBeCloseTo(expectedCost, 6); + + // Verify this is NOT using tiered rate on aggregated total + // If incorrectly aggregated: (200k * 0.0000003) + (300k * 0.0000006) = $0.24 + const incorrectAggregatedCost = 0.24; + expect(metrics.costUsd).not.toBeCloseTo(incorrectAggregatedCost, 2); + }); + + it('should apply tiered rates when individual messages exceed 200k', () => { + const messages: ParsedMessage[] = [ + { + type: 'assistant', + uuid: 'msg-1', + timestamp: new Date(), + content: [], + model: 'claude-3-5-sonnet-20241022', + usage: { + input_tokens: 0, + output_tokens: 0, + cache_read_input_tokens: 300000, // Exceeds 200k threshold + }, + toolCalls: [], + toolResults: [], + isSidechain: false, + }, + ]; + + const metrics = calculateMetrics(messages); + + // Single message with 300k cache_read tokens + // First 200k: 200,000 * 0.0000003 = $0.06 + // Remaining 100k: 100,000 * 0.0000006 = $0.06 + // Total: $0.12 + const expectedCost = 200000 * 0.0000003 + 100000 * 0.0000006; + expect(metrics.costUsd).toBeCloseTo(expectedCost, 6); + }); + }); + describe('Integration with Other Metrics', () => { it('should include cost alongside other session metrics', () => { const messages: ParsedMessage[] = [ diff --git a/test/shared/utils/costFormatting.test.ts b/test/shared/utils/costFormatting.test.ts new file mode 100644 index 00000000..10679009 --- /dev/null +++ b/test/shared/utils/costFormatting.test.ts @@ -0,0 +1,228 @@ +/** + * Tests for cost formatting utilities + */ + +import { describe, it, expect } from 'vitest'; +import { formatCostUsd, formatCostCompact } from '@shared/utils/costFormatting'; + +describe('Cost Formatting', () => { + describe('formatCostUsd', () => { + describe('Zero values', () => { + it('should format zero as $0.00', () => { + expect(formatCostUsd(0)).toBe('$0.00'); + }); + + it('should format negative zero as $0.00', () => { + expect(formatCostUsd(-0)).toBe('$0.00'); + }); + }); + + describe('Standard amounts (>= $0.01)', () => { + it('should format 1 cent with 2 decimal places', () => { + expect(formatCostUsd(0.01)).toBe('$0.01'); + }); + + it('should format 1 dollar with 2 decimal places', () => { + expect(formatCostUsd(1.0)).toBe('$1.00'); + }); + + it('should format dollars and cents', () => { + expect(formatCostUsd(1.23)).toBe('$1.23'); + }); + + it('should format large amounts', () => { + expect(formatCostUsd(999.99)).toBe('$999.99'); + expect(formatCostUsd(1234.56)).toBe('$1234.56'); + }); + + it('should round to 2 decimal places for amounts >= 1 cent', () => { + expect(formatCostUsd(1.234)).toBe('$1.23'); + expect(formatCostUsd(1.235)).toBe('$1.24'); // Rounds up + expect(formatCostUsd(1.999)).toBe('$2.00'); + }); + }); + + describe('Sub-cent amounts ($0.001 - $0.01)', () => { + it('should format 1 tenth of a cent with 3 decimal places', () => { + expect(formatCostUsd(0.001)).toBe('$0.001'); + }); + + it('should format sub-cent amounts with 3 decimal places', () => { + expect(formatCostUsd(0.005)).toBe('$0.005'); + expect(formatCostUsd(0.009)).toBe('$0.009'); + }); + + it('should round to 3 decimal places for sub-cent amounts', () => { + expect(formatCostUsd(0.0012)).toBe('$0.001'); + expect(formatCostUsd(0.0015)).toBe('$0.002'); // Rounds up + expect(formatCostUsd(0.0099)).toBe('$0.010'); + }); + }); + + describe('Very small amounts (< $0.001)', () => { + it('should format tiny amounts with 4 decimal places', () => { + expect(formatCostUsd(0.0001)).toBe('$0.0001'); + expect(formatCostUsd(0.0005)).toBe('$0.0005'); + expect(formatCostUsd(0.0009)).toBe('$0.0009'); + }); + + it('should round to 4 decimal places for tiny amounts', () => { + expect(formatCostUsd(0.00012)).toBe('$0.0001'); + expect(formatCostUsd(0.00016)).toBe('$0.0002'); // Rounds up + expect(formatCostUsd(0.00099)).toBe('$0.0010'); + }); + + it('should handle very tiny amounts', () => { + expect(formatCostUsd(0.000001)).toBe('$0.0000'); + }); + }); + + describe('Edge cases', () => { + it('should handle negative amounts with 4 decimal places', () => { + // Negative numbers don't match >= comparisons, so they use 4 decimals + expect(formatCostUsd(-1.23)).toBe('$-1.2300'); + expect(formatCostUsd(-0.001)).toBe('$-0.0010'); + expect(formatCostUsd(-0.0001)).toBe('$-0.0001'); + }); + + it('should handle very large amounts', () => { + expect(formatCostUsd(1000000)).toBe('$1000000.00'); + }); + + it('should handle precision boundaries', () => { + // Boundary between 2 and 3 decimal places + expect(formatCostUsd(0.01)).toBe('$0.01'); + expect(formatCostUsd(0.00999)).toBe('$0.010'); // Just below threshold, uses 3 decimals + + // Boundary between 3 and 4 decimal places + expect(formatCostUsd(0.001)).toBe('$0.001'); + expect(formatCostUsd(0.00099)).toBe('$0.0010'); // Just below threshold, uses 4 decimals + }); + }); + + describe('Real-world API cost examples', () => { + it('should format typical Claude API costs', () => { + // 1M input tokens at $3.00/M + expect(formatCostUsd(3.0)).toBe('$3.00'); + + // 100k input tokens at $3.00/M + expect(formatCostUsd(0.3)).toBe('$0.30'); + + // 10k cache read tokens at $0.30/M + expect(formatCostUsd(0.003)).toBe('$0.003'); + + // 1k cache read tokens at $0.30/M + expect(formatCostUsd(0.0003)).toBe('$0.0003'); + }); + + it('should format session totals', () => { + // Small session + expect(formatCostUsd(0.15)).toBe('$0.15'); + + // Medium session + expect(formatCostUsd(5.67)).toBe('$5.67'); + + // Large session + expect(formatCostUsd(29.57)).toBe('$29.57'); + }); + }); + }); + + describe('formatCostCompact', () => { + describe('Zero values', () => { + it('should format zero as 0.00', () => { + expect(formatCostCompact(0)).toBe('0.00'); + }); + + it('should format negative zero as 0.00', () => { + expect(formatCostCompact(-0)).toBe('0.00'); + }); + }); + + describe('Standard amounts (>= $0.01)', () => { + it('should format amounts without $ prefix', () => { + expect(formatCostCompact(0.01)).toBe('0.01'); + expect(formatCostCompact(1.0)).toBe('1.00'); + expect(formatCostCompact(1.23)).toBe('1.23'); + }); + + it('should format large amounts', () => { + expect(formatCostCompact(999.99)).toBe('999.99'); + expect(formatCostCompact(1234.56)).toBe('1234.56'); + }); + + it('should round to 2 decimal places', () => { + expect(formatCostCompact(1.234)).toBe('1.23'); + expect(formatCostCompact(1.235)).toBe('1.24'); // Rounds up + expect(formatCostCompact(1.999)).toBe('2.00'); + }); + }); + + describe('Sub-cent amounts ($0.001 - $0.01)', () => { + it('should format sub-cent amounts with 3 decimal places', () => { + expect(formatCostCompact(0.001)).toBe('0.001'); + expect(formatCostCompact(0.005)).toBe('0.005'); + expect(formatCostCompact(0.009)).toBe('0.009'); + }); + + it('should round to 3 decimal places', () => { + expect(formatCostCompact(0.0012)).toBe('0.001'); + expect(formatCostCompact(0.0015)).toBe('0.002'); // Rounds up + expect(formatCostCompact(0.0099)).toBe('0.010'); + }); + }); + + describe('Very small amounts (< $0.001)', () => { + it('should format tiny amounts with 4 decimal places', () => { + expect(formatCostCompact(0.0001)).toBe('0.0001'); + expect(formatCostCompact(0.0005)).toBe('0.0005'); + expect(formatCostCompact(0.0009)).toBe('0.0009'); + }); + + it('should round to 4 decimal places', () => { + expect(formatCostCompact(0.00012)).toBe('0.0001'); + expect(formatCostCompact(0.00016)).toBe('0.0002'); // Rounds up + expect(formatCostCompact(0.00099)).toBe('0.0010'); + }); + }); + + describe('Edge cases', () => { + it('should handle negative amounts with 4 decimal places', () => { + // Negative numbers don't match >= comparisons, so they use 4 decimals + expect(formatCostCompact(-1.23)).toBe('-1.2300'); + expect(formatCostCompact(-0.001)).toBe('-0.0010'); + expect(formatCostCompact(-0.0001)).toBe('-0.0001'); + }); + + it('should handle very large amounts', () => { + expect(formatCostCompact(1000000)).toBe('1000000.00'); + }); + }); + + describe('Comparison with formatCostUsd', () => { + it('should match formatCostUsd except for $ prefix', () => { + const testCases = [0, 0.0001, 0.001, 0.01, 1.23, 999.99]; + + testCases.forEach((cost) => { + const withPrefix = formatCostUsd(cost); + const compact = formatCostCompact(cost); + + // Compact should equal the USD format without the $ + expect(compact).toBe(withPrefix.substring(1)); + }); + }); + }); + + describe('Badge display use cases', () => { + it('should format for badge display', () => { + // Small per-message costs + expect(formatCostCompact(0.0015)).toBe('0.002'); + expect(formatCostCompact(0.01)).toBe('0.01'); + + // Session totals in badges + expect(formatCostCompact(2.5)).toBe('2.50'); + expect(formatCostCompact(15.0)).toBe('15.00'); + }); + }); + }); +});