Merge pull request #73 from holstein13/feat/unify-cost-calculation

Unify cost calculation with shared pricing module
matt 2026-02-24 11:29:50 +08:00 committed by GitHub
commit 68f16bb717
GPG key ID: B5690EEEBB952194
12 changed files with 824 additions and 317 deletions

@@ -0,0 +1,71 @@
# Unify Cost Calculation — Design Document
**Date:** 2026-02-23
**Branch:** `feat/unify-cost-calculation`
**Related:** PR #60 (Session Analysis Report), PR #65 (Cost Calculation Metric), Issue #72 (Plan Usage Tracking)
## Problem
Cost calculation exists in two places with different pricing data and logic:
1. **Main process** (`src/main/utils/jsonl.ts`): Uses LiteLLM-sourced `pricing.json` (206 models, tiered 200k-token pricing). Populates `SessionMetrics.costUsd` for the chat UI.
2. **Renderer** (`src/renderer/utils/sessionAnalyzer.ts`): Uses a hardcoded 6-model pricing table with no tiered pricing. Generates per-model cost breakdown for the Session Report.
The two systems can produce different cost numbers for the same session and will drift further as models change.
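To make the divergence concrete, here is a minimal sketch comparing a flat per-million table against tiered per-token pricing for the same token count. The rates and the helper names `flatCost` and `tieredCost` are illustrative assumptions; they are not values read from `pricing.json` or functions from either code path.

```typescript
// Hypothetical rates for illustration only; not read from pricing.json.
const FLAT_INPUT_PER_MILLION = 3.0; // renderer-style table: USD per 1M tokens
const BASE_INPUT_PER_TOKEN = 0.000003; // main-process style: USD per token
const TIERED_INPUT_PER_TOKEN = 0.000006; // assumed rate above the 200k threshold
const TIER_THRESHOLD = 200_000;

// Flat pricing: every token is billed at the same rate.
function flatCost(tokens: number): number {
  return (tokens * FLAT_INPUT_PER_MILLION) / 1_000_000;
}

// Tiered pricing: only tokens above the threshold use the higher rate.
function tieredCost(tokens: number): number {
  if (tokens <= TIER_THRESHOLD) return tokens * BASE_INPUT_PER_TOKEN;
  return (
    TIER_THRESHOLD * BASE_INPUT_PER_TOKEN +
    (tokens - TIER_THRESHOLD) * TIERED_INPUT_PER_TOKEN
  );
}

const tokens = 250_000;
console.log(`flat: $${flatCost(tokens).toFixed(2)}, tiered: $${tieredCost(tokens).toFixed(2)}`);
```

Under these assumed rates the two approaches agree below the threshold and diverge above it (0.75 versus 0.90 USD for 250k input tokens), which is the kind of mismatch a single shared module is meant to remove.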
## Solution
Create a single shared pricing module that both processes import.
### New Module: `src/shared/utils/pricing.ts`
**Exports:**
| Export | Description |
|--------|-------------|
| `ModelPricing` | Interface for per-model rates (input, output, cache read, cache creation, plus tiered variants) |
| `getPricing(modelName: string): ModelPricing \| null` | Model lookup: exact match, lowercase, case-insensitive scan |
| `calculateTieredCost(tokens, baseRate, tieredRate?): number` | Applies 200k-token tier threshold |
| `calculateMessageCost(model, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens): number` | Computes cost for a single API call |
**Pricing data:** Static `import pricingData from '../../../resources/pricing.json'` with `resolveJsonModule: true`. Replaces `fs.readFileSync` runtime loading.
### Consumer Changes
**`src/main/utils/jsonl.ts`:**
- Remove: `ModelPricing` interface, `loadPricingData()`, `calculateTieredCost()`, `getPricing()`, `fs`/`path` imports
- Keep: `calculateMetrics()` function
- Change: inline cost loop body → call `calculateMessageCost()` from `@shared/utils/pricing`
**`src/renderer/utils/sessionAnalyzer.ts`:**
- Remove: `MODEL_PRICING` table (~40 lines), `DEFAULT_PRICING`, local `getPricing()`, local `costUsd()`
- Change: calls at lines 476 and 900 → `calculateMessageCost()` from `@shared/utils/pricing`
**Tests:**
- `test/main/utils/costCalculation.test.ts` → update to test shared module functions
- `test/renderer/utils/sessionAnalyzer.test.ts` → mock `@shared/utils/pricing` instead of local functions
- New `test/shared/utils/pricing.test.ts` for the shared module
### Pricing JSON Import Strategy
- `pricing.json` stays in `resources/` for Electron's `extraResources` packaging
- Both Vite (renderer) and electron-vite (main) resolve the JSON import at compile time
- Remove the `fs.readFileSync` dev/prod path logic from `jsonl.ts`
### Fallback Behavior
- `getPricing()` returns `null` for unknown models
- `calculateMessageCost()` returns `0` for unknown models (matches current `jsonl.ts` behavior)
- Session analyzer callers can apply a default if needed
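
The fallback rules above can be sketched as follows. This is a simplified stand-in, not the shared module itself: `KNOWN_RATES`, `CALLER_DEFAULT`, `lookupRates`, `messageCost`, and `messageCostWithDefault` are hypothetical names, and the rates are placeholders.

```typescript
// Hypothetical per-token rates; the real data lives in pricing.json.
interface Rates {
  input: number;
  output: number;
}

const KNOWN_RATES = new Map<string, Rates>([
  ['example-model', { input: 0.000003, output: 0.000015 }],
]);

// Assumed default a caller might choose; the design leaves this to callers.
const CALLER_DEFAULT: Rates = { input: 0.000003, output: 0.000015 };

// Mirrors the documented lookup contract: null for unknown models.
function lookupRates(model: string): Rates | null {
  return KNOWN_RATES.get(model) ?? null;
}

// Mirrors the documented cost contract: unknown models cost 0.
function messageCost(model: string, inputTokens: number, outputTokens: number): number {
  const rates = lookupRates(model);
  if (!rates) return 0;
  return inputTokens * rates.input + outputTokens * rates.output;
}

// A caller that prefers an estimate over 0 applies its own default.
function messageCostWithDefault(
  model: string,
  inputTokens: number,
  outputTokens: number
): number {
  const rates = lookupRates(model) ?? CALLER_DEFAULT;
  return inputTokens * rates.input + outputTokens * rates.output;
}

console.log(messageCost('unknown-model', 1000, 500));
console.log(messageCostWithDefault('unknown-model', 1000, 500).toFixed(4));
```

Keeping the default out of the shared module means the chat view and the report start from the same `0` for unknown models, and any estimate is an explicit choice at the call site.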
### What Changes for Users
- Report costs become more accurate (tiered pricing, 206 models instead of 6)
- Cost numbers between chat view and Session Report now agree exactly
- Small UI change: Visible Context header adds a "parent only · view full cost" action when available
## Out of Scope
- Plan usage tracking (see Issue #72 — pending community feedback)
- New UI surfaces for cost display
- Changes to the `costFormatting.ts` shared utility

@@ -0,0 +1,435 @@
# Unify Cost Calculation Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Replace dual cost calculation systems with a single shared pricing module used by both main and renderer processes.
**Architecture:** Create `src/shared/utils/pricing.ts` that statically imports `resources/pricing.json` and exports all pricing functions. Both `jsonl.ts` (main) and `sessionAnalyzer.ts` (renderer) consume this module instead of maintaining their own pricing logic.
**Tech Stack:** TypeScript, Vitest, electron-vite (resolveJsonModule)
---
## Tasks
### Task 1: Create the shared pricing module with tests
**Files:**
- Create: `src/shared/utils/pricing.ts`
- Test: `test/shared/utils/pricing.test.ts`
**Step 1: Write the failing tests**
Create `test/shared/utils/pricing.test.ts`:
```typescript
import { describe, it, expect } from 'vitest';

import {
  getPricing,
  calculateTieredCost,
  calculateMessageCost,
  getDisplayPricing,
} from '@shared/utils/pricing';

describe('Shared Pricing Module', () => {
  describe('getPricing', () => {
    it('should find pricing by exact model name', () => {
      // Use a model known to exist in pricing.json
      const pricing = getPricing('claude-3-5-sonnet-20241022');
      expect(pricing).not.toBeNull();
      expect(pricing!.input_cost_per_token).toBeGreaterThan(0);
      expect(pricing!.output_cost_per_token).toBeGreaterThan(0);
    });

    it('should find pricing case-insensitively', () => {
      const pricing = getPricing('Claude-3-5-Sonnet-20241022');
      expect(pricing).not.toBeNull();
    });

    it('should return null for unknown models', () => {
      const pricing = getPricing('totally-fake-model-xyz');
      expect(pricing).toBeNull();
    });
  });

  describe('calculateTieredCost', () => {
    it('should use base rate for tokens below 200k', () => {
      const cost = calculateTieredCost(100_000, 0.000003);
      expect(cost).toBeCloseTo(0.3, 6);
    });

    it('should apply tiered rate above 200k', () => {
      const cost = calculateTieredCost(250_000, 0.000003, 0.000006);
      // (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
      expect(cost).toBeCloseTo(0.9, 6);
    });

    it('should use base rate when no tiered rate provided', () => {
      const cost = calculateTieredCost(250_000, 0.000015);
      expect(cost).toBeCloseTo(3.75, 6);
    });

    it('should return 0 for zero or negative tokens', () => {
      expect(calculateTieredCost(0, 0.000003)).toBe(0);
      expect(calculateTieredCost(-100, 0.000003)).toBe(0);
    });
  });

  describe('calculateMessageCost', () => {
    it('should compute cost for a known model', () => {
      // claude-3-5-sonnet-20241022: input=0.000003, output=0.000015
      const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 0, 0);
      // (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
      expect(cost).toBeCloseTo(0.0105, 6);
    });

    it('should return 0 for unknown models', () => {
      const cost = calculateMessageCost('unknown-model', 1000, 500, 0, 0);
      expect(cost).toBe(0);
    });

    it('should include cache token costs', () => {
      const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 300, 200);
      expect(cost).toBeGreaterThan(0.0105); // more than just input+output
    });
  });

  describe('getDisplayPricing', () => {
    it('should return per-million rates for a known model', () => {
      const dp = getDisplayPricing('claude-3-5-sonnet-20241022');
      expect(dp).not.toBeNull();
      expect(dp!.input).toBeCloseTo(3.0, 1); // $3/M input
      expect(dp!.output).toBeCloseTo(15.0, 1); // $15/M output
    });

    it('should return null for unknown models', () => {
      expect(getDisplayPricing('unknown-model')).toBeNull();
    });
  });
});
```
**Step 2: Run tests to verify they fail**
Run: `pnpm vitest run test/shared/utils/pricing.test.ts`
Expected: FAIL — module does not exist
**Step 3: Create the shared pricing module**
Create `src/shared/utils/pricing.ts`:
```typescript
import pricingData from '../../../resources/pricing.json';

export interface LiteLLMPricing {
  input_cost_per_token: number;
  output_cost_per_token: number;
  cache_creation_input_token_cost?: number;
  cache_read_input_token_cost?: number;
  input_cost_per_token_above_200k_tokens?: number;
  output_cost_per_token_above_200k_tokens?: number;
  cache_creation_input_token_cost_above_200k_tokens?: number;
  cache_read_input_token_cost_above_200k_tokens?: number;
  [key: string]: unknown;
}

export interface DisplayPricing {
  input: number;
  output: number;
  cache_read: number;
  cache_creation: number;
}

const TIER_THRESHOLD = 200_000;
const pricingMap = pricingData as Record<string, unknown>;

function tryGetPricing(key: string): LiteLLMPricing | null {
  const entry = pricingMap[key];
  if (
    entry &&
    typeof entry === 'object' &&
    'input_cost_per_token' in entry &&
    'output_cost_per_token' in entry
  ) {
    return entry as LiteLLMPricing;
  }
  return null;
}

export function getPricing(modelName: string): LiteLLMPricing | null {
  const exact = tryGetPricing(modelName);
  if (exact) return exact;

  const lowerName = modelName.toLowerCase();
  const lower = tryGetPricing(lowerName);
  if (lower) return lower;

  for (const key of Object.keys(pricingMap)) {
    if (key.toLowerCase() === lowerName) {
      const match = tryGetPricing(key);
      if (match) return match;
    }
  }
  return null;
}

export function calculateTieredCost(
  tokens: number,
  baseRate: number,
  tieredRate?: number
): number {
  if (tokens <= 0) return 0;
  if (!tieredRate || tokens <= TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  const costBelow = TIER_THRESHOLD * baseRate;
  const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
  return costBelow + costAbove;
}

export function calculateMessageCost(
  modelName: string,
  inputTokens: number,
  outputTokens: number,
  cacheReadTokens: number,
  cacheCreationTokens: number
): number {
  const pricing = getPricing(modelName);
  if (!pricing) return 0;

  const inputCost = calculateTieredCost(
    inputTokens,
    pricing.input_cost_per_token,
    pricing.input_cost_per_token_above_200k_tokens
  );
  const outputCost = calculateTieredCost(
    outputTokens,
    pricing.output_cost_per_token,
    pricing.output_cost_per_token_above_200k_tokens
  );
  const cacheCreationCost = calculateTieredCost(
    cacheCreationTokens,
    pricing.cache_creation_input_token_cost ?? 0,
    pricing.cache_creation_input_token_cost_above_200k_tokens
  );
  const cacheReadCost = calculateTieredCost(
    cacheReadTokens,
    pricing.cache_read_input_token_cost ?? 0,
    pricing.cache_read_input_token_cost_above_200k_tokens
  );
  return inputCost + outputCost + cacheCreationCost + cacheReadCost;
}

export function getDisplayPricing(modelName: string): DisplayPricing | null {
  const pricing = getPricing(modelName);
  if (!pricing) return null;
  return {
    input: pricing.input_cost_per_token * 1_000_000,
    output: pricing.output_cost_per_token * 1_000_000,
    cache_read: (pricing.cache_read_input_token_cost ?? 0) * 1_000_000,
    cache_creation: (pricing.cache_creation_input_token_cost ?? 0) * 1_000_000,
  };
}
```
**Step 4: Run tests to verify they pass**
Run: `pnpm vitest run test/shared/utils/pricing.test.ts`
Expected: PASS
**Step 5: Commit**
```bash
git add src/shared/utils/pricing.ts test/shared/utils/pricing.test.ts
git commit -m "feat: add shared pricing module with LiteLLM data"
```
---
### Task 2: Wire jsonl.ts to use the shared pricing module
**Files:**
- Modify: `src/main/utils/jsonl.ts:219-400` (remove pricing functions, update calculateMetrics)
- Test: `test/main/utils/costCalculation.test.ts` (update to remove fs mocking)
**Step 1: Update the cost calculation test to use static imports**
The existing tests mock `fs.readFileSync` to provide pricing data. Since the shared module uses a static JSON import, the tests should instead test against real pricing data or mock the shared module.
Update `test/main/utils/costCalculation.test.ts`:
- Remove `import * as fs from 'fs'` and `vi.mock('fs')`
- Remove `mockPricingData` and the `beforeEach` that mocks `fs.readFileSync`
- Update model names in tests to match models that exist in `resources/pricing.json` (the existing `claude-3-5-sonnet-20241022` and `claude-3-opus-20240229` should already be there)
- Update expected cost values to match the actual rates from `pricing.json` (verify they match the existing mock data — they should be identical since the mock was based on real rates)
- Remove the "pricing data load failure" test (lines 409-449), since there is no runtime file loading left to fail
- Keep all other test cases and assertions as-is
**Step 2: Run updated tests to verify they fail**
Run: `pnpm vitest run test/main/utils/costCalculation.test.ts`
Expected: FAIL — jsonl.ts still has old imports
**Step 3: Update jsonl.ts**
In `src/main/utils/jsonl.ts`:
- Remove lines 219-320: the `fs`/`path` imports, `ModelPricing` interface, `TIER_THRESHOLD`, `pricingCache`, `loadPricingData()`, `calculateTieredCost()`, `getPricing()`
- Add import at top of file: `import { calculateMessageCost } from '@shared/utils/pricing';`
- In `calculateMetrics()` (lines 354-400), replace the inline cost calculation block (lines 374-398) with:
```typescript
if (msg.model) {
  costUsd += calculateMessageCost(
    msg.model,
    msgInputTokens,
    msgOutputTokens,
    msgCacheReadTokens,
    msgCacheCreationTokens
  );
}
```
- Remove the unused `modelName` variable (line 338) and the block that sets it (lines 370-372)
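
To illustrate why a session-level `modelName` is no longer needed, here is a minimal sketch of per-message accumulation in which each message is priced by its own model. The `Msg` shape, the `RATES` table, and the `messageCost` and `totalCost` helpers are hypothetical stand-ins for `ParsedMessage`, `pricing.json`, `calculateMessageCost()`, and `calculateMetrics()`.

```typescript
// Hypothetical message shape, for illustration only.
interface Msg {
  model?: string;
  inputTokens: number;
  outputTokens: number;
}

// Made-up per-token rates standing in for pricing.json.
const RATES = new Map<string, { input: number; output: number }>([
  ['model-a', { input: 0.000003, output: 0.000015 }],
  ['model-b', { input: 0.000001, output: 0.000005 }],
]);

// Stand-in for calculateMessageCost: unknown models cost 0.
function messageCost(model: string, inputTokens: number, outputTokens: number): number {
  const r = RATES.get(model);
  return r ? inputTokens * r.input + outputTokens * r.output : 0;
}

function totalCost(messages: Msg[]): number {
  let costUsd = 0;
  for (const msg of messages) {
    // Messages without a model contribute no cost.
    if (msg.model) {
      costUsd += messageCost(msg.model, msg.inputTokens, msg.outputTokens);
    }
  }
  return costUsd;
}

const total = totalCost([
  { model: 'model-a', inputTokens: 1000, outputTokens: 500 },
  { model: 'model-b', inputTokens: 2000, outputTokens: 100 },
  { inputTokens: 50, outputTokens: 50 },
]);
console.log(total.toFixed(4));
```

Because the model is read from each message, a session that switches models mid-way is still priced correctly without tracking the first model seen.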
**Step 4: Run tests to verify they pass**
Run: `pnpm vitest run test/main/utils/costCalculation.test.ts`
Expected: PASS
**Step 5: Run full test suite to check for regressions**
Run: `pnpm test`
Expected: All tests pass
**Step 6: Commit**
```bash
git add src/main/utils/jsonl.ts test/main/utils/costCalculation.test.ts
git commit -m "refactor: wire jsonl.ts to shared pricing module"
```
---
### Task 3: Wire sessionAnalyzer.ts and CostSection.tsx to use the shared pricing module
**Files:**
- Modify: `src/renderer/utils/sessionAnalyzer.ts:32,60-130` (remove local pricing)
- Modify: `src/renderer/types/sessionReport.ts:23-28` (update ModelPricing type)
- Modify: `src/renderer/components/report/sections/CostSection.tsx:3,10,39-46,204` (update imports and usage)
**Step 1: Run existing session analyzer tests as baseline**
Run: `pnpm vitest run test/renderer/utils/sessionAnalyzer.test.ts`
Expected: PASS (baseline)
**Step 2: Update sessionAnalyzer.ts**
In `src/renderer/utils/sessionAnalyzer.ts`:
- Remove the `ModelPricing` import from `@renderer/types/sessionReport` (line 32)
- Remove lines 60-130: `MODEL_PRICING` table, `DEFAULT_PRICING`, `getPricing()`, `costUsd()`
- Add import: `import { calculateMessageCost, getDisplayPricing } from '@shared/utils/pricing';`
- Export `getDisplayPricing` as `getPricing` for backward compat with CostSection: `export { getDisplayPricing as getPricing } from '@shared/utils/pricing';`
- Replace `costUsd(model, inpTok, outTok, cr, cc)` at line 476 with `calculateMessageCost(model, inpTok, outTok, cr, cc)`
- Replace `costUsd(subagentModel, ...)` at line 900 with `calculateMessageCost(subagentModel, proc.metrics.inputTokens, proc.metrics.outputTokens, proc.metrics.cacheReadTokens, proc.metrics.cacheCreationTokens)`
**Step 3: Update sessionReport.ts ModelPricing type**
In `src/renderer/types/sessionReport.ts`:
- Replace the existing `ModelPricing` interface (lines 23-28) with a re-export from the shared module:
```typescript
export type { DisplayPricing as ModelPricing } from '@shared/utils/pricing';
```
This keeps backward compatibility — `CostSection.tsx` imports `ModelPricing` from here and expects `{ input, output, cache_read, cache_creation }` which matches `DisplayPricing`.
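
A minimal sketch of why the alias is safe, assuming the `DisplayPricing` shape shown in Task 1. The local `type ModelPricing = DisplayPricing` stands in for the re-export, and `describeRates` is a hypothetical consumer written against the old name.

```typescript
// Local copy of the DisplayPricing shape from the shared module.
interface DisplayPricing {
  input: number;
  output: number;
  cache_read: number;
  cache_creation: number;
}

// Stand-in for: export type { DisplayPricing as ModelPricing } from '@shared/utils/pricing';
type ModelPricing = DisplayPricing;

// Hypothetical consumer that still refers to the old ModelPricing name.
function describeRates(pricing: ModelPricing): string {
  return `in $${pricing.input}/M, out $${pricing.output}/M`;
}

const rates: DisplayPricing = { input: 3, output: 15, cache_read: 0.3, cache_creation: 3.75 };
console.log(describeRates(rates));
```

Since the alias introduces no new type, existing annotations that mention `ModelPricing` keep compiling as long as they only rely on these four fields.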
**Step 4: Update CostSection.tsx**
In `src/renderer/components/report/sections/CostSection.tsx`:
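- Line 3: Keep `import { getPricing } from '@renderer/utils/sessionAnalyzer'` as-is. `sessionAnalyzer.ts` re-exports `getDisplayPricing` under the name `getPricing`, so the import path does not change; verify that it still resolves. Note that the re-exported function returns `DisplayPricing | null` instead of always falling back to a default, so callers must handle the `null` case.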
- The `ModelPricing` import from `@renderer/types/sessionReport` (line 10) continues to work via the re-export.
- The `CostBreakdownCard` (lines 34-46) uses `pricing.input`, `pricing.output`, etc. as per-million rates — this matches `DisplayPricing` from `getDisplayPricing()`.
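
The unit relationship between the two representations can be sketched as below. The rates are assumed values for illustration, and `toPerMillion` and `costFromPerMillion` are hypothetical helpers, not part of the shared module.

```typescript
// Assumed per-token rates for illustration; real values come from pricing.json.
const perToken = { input: 0.000003, output: 0.000015 };

// Convert a per-token rate to USD per one million tokens (the display unit).
function toPerMillion(ratePerToken: number): number {
  return ratePerToken * 1_000_000;
}

// Computing a cost from a per-million rate requires dividing by 1M again.
function costFromPerMillion(tokens: number, ratePerMillion: number): number {
  return (tokens * ratePerMillion) / 1_000_000;
}

const display = {
  input: toPerMillion(perToken.input),
  output: toPerMillion(perToken.output),
};
console.log(display.input.toFixed(2), display.output.toFixed(2));
```

Display code should treat these values as presentation-only; actual costs are best computed from the per-token rates so that tiered pricing is applied consistently.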
**Step 5: Run session analyzer tests**
Run: `pnpm vitest run test/renderer/utils/sessionAnalyzer.test.ts`
Expected: PASS
**Step 6: Run full test suite**
Run: `pnpm test`
Expected: All tests pass
**Step 7: Commit**
```bash
git add src/renderer/utils/sessionAnalyzer.ts src/renderer/types/sessionReport.ts src/renderer/components/report/sections/CostSection.tsx
git commit -m "refactor: wire session analyzer and CostSection to shared pricing"
```
---
### Task 4: Typecheck, lint, and verify the app runs
**Files:**
- No new files — verification only
**Step 1: Run typecheck**
Run: `pnpm typecheck`
Expected: No errors
**Step 2: Run linter**
Run: `pnpm lint:fix`
Expected: Clean or auto-fixed
**Step 3: Run full test suite**
Run: `pnpm test`
Expected: All tests pass
**Step 4: Run the app and verify cost display**
Run: `pnpm dev`
- Open a session with known token usage
- Verify `TokenUsageDisplay` shows cost in the chat view
- Open the Session Report tab and verify cost-by-model breakdown renders
- Verify CostBreakdownCard expands with per-token-type rates
- Confirm chat view cost and report cost show the same total
**Step 5: Commit any fixes from verification**
If any fixes were needed, commit them:
```bash
git add -A
git commit -m "fix: address typecheck/lint issues from cost unification"
```
---
### Task 5: Clean up dead code from package.json extraResources
**Files:**
- Modify: `package.json` (optional — evaluate if `extraResources` for pricing.json is still needed)
**Step 1: Check if pricing.json is still loaded at runtime anywhere**
Search for any remaining `fs.readFileSync` or runtime references to `pricing.json`:
Run: `grep -r "pricing.json" src/`
Expected: Only the static import in `src/shared/utils/pricing.ts`
**Step 2: Evaluate extraResources**
If no runtime file loading remains, the `extraResources` entry for `pricing.json` in `package.json` is dead config. Removing it is low-risk but also low-priority: leaving it in place only means the file is copied into the app bundle unnecessarily. Leave it for now unless it causes issues, and document the decision.
**Step 3: Final commit**
```bash
git add docs/plans/2026-02-23-unify-cost-calculation.md
git commit -m "docs: finalize implementation plan for cost unification"
```

@@ -9,6 +9,7 @@
import { isCommandOutputContent, sanitizeDisplayContent } from '@shared/utils/contentSanitizer';
import { createLogger } from '@shared/utils/logger';
import { calculateMessageCost } from '@shared/utils/pricing';
import * as readline from 'readline';
import { LocalFileSystemProvider } from '../services/infrastructure/LocalFileSystemProvider';
@@ -212,113 +213,6 @@ function parseMessageType(type?: string): MessageType | null {
}
}
// =============================================================================
// Cost Calculation
// =============================================================================
import * as fs from 'fs';
import * as path from 'path';
interface ModelPricing {
input_cost_per_token: number;
output_cost_per_token: number;
cache_creation_input_token_cost?: number;
cache_read_input_token_cost?: number;
input_cost_per_token_above_200k_tokens?: number;
output_cost_per_token_above_200k_tokens?: number;
cache_creation_input_token_cost_above_200k_tokens?: number;
cache_read_input_token_cost_above_200k_tokens?: number;
[key: string]: unknown;
}
const TIER_THRESHOLD = 200_000;
// Cache pricing data in memory (loaded once on first use)
let pricingCache: Record<string, unknown> | null = null;
/**
* Load pricing data from resources directory.
* Uses electron-vite resource directory pattern:
* - Development: resources/pricing.json (project root)
* - Production: process.resourcesPath/pricing.json
*/
function loadPricingData(): Record<string, unknown> {
if (pricingCache !== null) {
return pricingCache;
}
try {
// Determine if we're in development or production
const isDev = process.env.NODE_ENV === 'development' || !process.resourcesPath;
let pricingPath: string;
if (isDev) {
// Development: Compiled code is in dist-electron/main/
// __dirname = /path/to/project/dist-electron/main
// Need to go up 2 levels to reach project root, then into resources/
pricingPath = path.join(__dirname, '..', '..', 'resources', 'pricing.json');
} else {
// Production: pricing.json in app's resources directory
pricingPath = path.join(process.resourcesPath, 'pricing.json');
}
const data = fs.readFileSync(pricingPath, 'utf-8');
pricingCache = JSON.parse(data) as Record<string, unknown>;
return pricingCache;
} catch (error) {
console.error('Failed to load pricing data:', error);
// Return empty object if pricing data can't be loaded
pricingCache = {};
return pricingCache;
}
}
function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
if (tokens <= 0) return 0;
if (!tieredRate || tokens <= TIER_THRESHOLD) {
return tokens * baseRate;
}
const costBelow = TIER_THRESHOLD * baseRate;
const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
return costBelow + costAbove;
}
function getPricing(modelName: string): ModelPricing | null {
const pricing = loadPricingData();
const tryGet = (key: string): ModelPricing | null => {
const entry = pricing[key];
if (
entry &&
typeof entry === 'object' &&
'input_cost_per_token' in entry &&
'output_cost_per_token' in entry
) {
return entry as ModelPricing;
}
return null;
};
// Try exact match
const exact = tryGet(modelName);
if (exact) return exact;
// Try lowercase
const lowerName = modelName.toLowerCase();
const lower = tryGet(lowerName);
if (lower) return lower;
// Try case-insensitive search
for (const key of Object.keys(pricing)) {
if (key.toLowerCase() === lowerName) {
const match = tryGet(key);
if (match) return match;
}
}
return null;
}
// =============================================================================
// Metrics Calculation
// =============================================================================
@@ -335,7 +229,6 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics {
let outputTokens = 0;
let cacheReadTokens = 0;
let cacheCreationTokens = 0;
let modelName: string | undefined;
// Get timestamps for duration (loop instead of Math.min/max spread to avoid stack overflow on large sessions)
const timestamps = messages.map((m) => m.timestamp.getTime()).filter((t) => !isNaN(t));
@@ -366,36 +259,14 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics {
cacheReadTokens += msgCacheReadTokens;
cacheCreationTokens += msgCacheCreationTokens;
// Calculate cost for this message if we have pricing data
if (msg.model && !modelName) {
modelName = msg.model;
}
if (msg.model) {
const pricing = getPricing(msg.model);
if (pricing) {
const inputCost = calculateTieredCost(
msgInputTokens,
pricing.input_cost_per_token,
pricing.input_cost_per_token_above_200k_tokens
);
const outputCost = calculateTieredCost(
msgOutputTokens,
pricing.output_cost_per_token,
pricing.output_cost_per_token_above_200k_tokens
);
const cacheCreationCost = calculateTieredCost(
msgCacheCreationTokens,
pricing.cache_creation_input_token_cost ?? 0,
pricing.cache_creation_input_token_cost_above_200k_tokens
);
const cacheReadCost = calculateTieredCost(
msgCacheReadTokens,
pricing.cache_read_input_token_cost ?? 0,
pricing.cache_read_input_token_cost_above_200k_tokens
);
costUsd += inputCost + outputCost + cacheCreationCost + cacheReadCost;
}
costUsd += calculateMessageCost(
msg.model,
msgInputTokens,
msgOutputTokens,
msgCacheReadTokens,
msgCacheCreationTokens
);
}
}
}

@@ -64,6 +64,7 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
syncSearchMatchesWithRendered,
selectSearchMatch,
setTabVisibleAIGroup,
openSessionReport,
} = useStore(
useShallow((s) => ({
searchQuery: s.searchQuery,
@@ -76,6 +77,7 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
syncSearchMatchesWithRendered: s.syncSearchMatchesWithRendered,
selectSearchMatch: s.selectSearchMatch,
setTabVisibleAIGroup: s.setTabVisibleAIGroup,
openSessionReport: s.openSessionReport,
}))
);
@@ -100,6 +102,14 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
sessionDetail,
} = tabData;
// Compute combined subagent cost from process metrics
const subagentCostUsd = useMemo(() => {
const processes = sessionDetail?.processes;
if (!processes || processes.length === 0) return undefined;
const total = processes.reduce((sum, p) => sum + (p.metrics.costUsd ?? 0), 0);
return total > 0 ? total : undefined;
}, [sessionDetail?.processes]);
// State for Context button hover (local state OK - doesn't need per-tab isolation)
const [isContextButtonHovered, setIsContextButtonHovered] = useState(false);
@@ -872,6 +882,8 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
onNavigateToUserGroup={handleNavigateToUserGroup}
totalSessionTokens={lastAiGroupTotalTokens}
sessionMetrics={sessionDetail?.metrics}
subagentCostUsd={subagentCostUsd}
onViewReport={effectiveTabId ? () => openSessionReport(effectiveTabId) : undefined}
phaseInfo={sessionPhaseInfo ?? undefined}
selectedPhase={selectedContextPhase}
onPhaseChange={setSelectedContextPhase}

@@ -28,7 +28,9 @@ interface SessionContextHeaderProps {
totalTokens: number;
totalSessionTokens?: number;
sessionMetrics?: SessionMetrics;
subagentCostUsd?: number;
onClose?: () => void;
onViewReport?: () => void;
phaseInfo?: ContextPhaseInfo;
selectedPhase: number | null;
onPhaseChange: (phase: number | null) => void;
@@ -41,7 +43,9 @@ export const SessionContextHeader = ({
totalTokens,
totalSessionTokens,
sessionMetrics,
subagentCostUsd,
onClose,
onViewReport,
phaseInfo,
selectedPhase,
onPhaseChange,
@@ -130,8 +134,30 @@ export const SessionContextHeader = ({
<div className="col-span-2">
<span style={{ color: COLOR_TEXT_MUTED }}>Session Cost: </span>
<span className="font-medium tabular-nums" style={{ color: COLOR_TEXT_SECONDARY }}>
{formatCostUsd(sessionMetrics.costUsd)}
{formatCostUsd(sessionMetrics.costUsd + (subagentCostUsd ?? 0))}
</span>
{subagentCostUsd !== undefined && subagentCostUsd > 0 && (
<span style={{ color: COLOR_TEXT_MUTED }}>
{' ('}
{formatCostUsd(sessionMetrics.costUsd)}
{' parent + '}
{formatCostUsd(subagentCostUsd)}
{' subagents'}
{onViewReport && (
<>
{' · '}
<button
onClick={onViewReport}
className="underline"
style={{ color: COLOR_TEXT_SECONDARY }}
>
details
</button>
</>
)}
{')'}
</span>
)}
</div>
)}
</div>

@@ -49,6 +49,8 @@ export const SessionContextPanel = ({
onNavigateToUserGroup,
totalSessionTokens,
sessionMetrics,
subagentCostUsd,
onViewReport,
phaseInfo,
selectedPhase,
onPhaseChange,
@@ -192,7 +194,9 @@ export const SessionContextPanel = ({
totalTokens={totalTokens}
totalSessionTokens={totalSessionTokens}
sessionMetrics={sessionMetrics}
subagentCostUsd={subagentCostUsd}
onClose={onClose}
onViewReport={onViewReport}
phaseInfo={phaseInfo}
selectedPhase={selectedPhase}
onPhaseChange={onPhaseChange}

@@ -27,6 +27,10 @@ export interface SessionContextPanelProps {
totalSessionTokens?: number;
/** Full session metrics (input, output, cache tokens, cost) */
sessionMetrics?: SessionMetrics;
/** Combined cost of all subagent processes */
subagentCostUsd?: number;
/** Open the Session Report to see full cost breakdown */
onViewReport?: () => void;
/** Phase information for phase selector */
phaseInfo?: ContextPhaseInfo;
/** Currently selected phase (null = current/latest) */

@@ -20,12 +20,7 @@ import type {
// Pricing
// =============================================================================
export interface ModelPricing {
input: number;
output: number;
cache_read: number;
cache_creation: number;
}
export type { DisplayPricing as ModelPricing } from '@shared/utils/pricing';
// =============================================================================
// Report Sections

@@ -22,6 +22,7 @@ import {
detectModelMismatch,
detectSwitchPattern,
} from '@renderer/utils/reportAssessments';
import { calculateMessageCost } from '@shared/utils/pricing';
import type {
AgentTreeNode,
@@ -29,7 +30,6 @@ import type {
GitCommit,
IdleGap,
KeyEvent,
ModelPricing,
ModelSwitch,
ModelTokenStats,
OutOfScopeFindings,
@@ -53,81 +53,8 @@ import type {
ToolCall,
} from '@shared/types';
// =============================================================================
// Pricing Table (USD per 1M tokens)
// =============================================================================
const MODEL_PRICING: Record<string, ModelPricing> = {
'opus-4': {
input: 15.0,
output: 75.0,
cache_read: 1.5,
cache_creation: 18.75,
},
'sonnet-4': {
input: 3.0,
output: 15.0,
cache_read: 0.3,
cache_creation: 3.75,
},
'haiku-4': {
input: 0.8,
output: 4.0,
cache_read: 0.08,
cache_creation: 1.0,
},
'opus-3': {
input: 15.0,
output: 75.0,
cache_read: 1.5,
cache_creation: 18.75,
},
'sonnet-3': {
input: 3.0,
output: 15.0,
cache_read: 0.3,
cache_creation: 3.75,
},
'haiku-3': {
input: 0.25,
output: 1.25,
cache_read: 0.03,
cache_creation: 0.3,
},
};
const DEFAULT_PRICING: ModelPricing = {
input: 3.0,
output: 15.0,
cache_read: 0.3,
cache_creation: 3.75,
};
export function getPricing(modelName: string): ModelPricing {
const nameTokens: string[] = modelName.toLowerCase().match(/[a-z0-9]+/g) ?? [];
for (const [key, pricing] of Object.entries(MODEL_PRICING)) {
const keyTokens: string[] = key.match(/[a-z0-9]+/g) ?? [];
if (keyTokens.every((t) => nameTokens.includes(t))) return pricing;
}
return DEFAULT_PRICING;
}
function costUsd(
modelName: string,
inputTok: number,
outputTok: number,
cacheReadTok: number,
cacheCreationTok: number
): number {
const p = getPricing(modelName);
return (
(inputTok * p.input +
outputTok * p.output +
cacheReadTok * p.cache_read +
cacheCreationTok * p.cache_creation) /
1_000_000
);
}
// Re-export getDisplayPricing as getPricing for backward compat with CostSection
export { getDisplayPricing as getPricing } from '@shared/utils/pricing';
// =============================================================================
// Helpers
@@ -473,7 +400,7 @@ export function analyzeSession(detail: SessionDetail): SessionReport {
stats.cacheCreation += cc;
stats.cacheRead += cr;
const callCost = costUsd(model, inpTok, outTok, cr, cc);
const callCost = calculateMessageCost(model, inpTok, outTok, cr, cc);
stats.costUsd += callCost;
parentCost += callCost;
@@ -897,7 +824,7 @@ export function analyzeSession(detail: SessionDetail): SessionReport {
proc.messages.find((m: ParsedMessage) => m.type === 'assistant' && m.model)?.model ??
'default (inherits parent)';
// Compute cost from subagent token breakdown (proc.metrics.costUsd is not populated upstream)
const computedCost = costUsd(
const computedCost = calculateMessageCost(
subagentModel,
proc.metrics.inputTokens,
proc.metrics.outputTokens,

src/shared/utils/pricing.ts

@@ -0,0 +1,121 @@
// eslint-disable-next-line no-restricted-imports -- resources/ is outside src/, no alias available
import pricingData from '../../../resources/pricing.json';

export interface LiteLLMPricing {
  input_cost_per_token: number;
  output_cost_per_token: number;
  cache_creation_input_token_cost?: number;
  cache_read_input_token_cost?: number;
  input_cost_per_token_above_200k_tokens?: number;
  output_cost_per_token_above_200k_tokens?: number;
  cache_creation_input_token_cost_above_200k_tokens?: number;
  cache_read_input_token_cost_above_200k_tokens?: number;
  [key: string]: unknown;
}

export interface DisplayPricing {
  input: number;
  output: number;
  cache_read: number;
  cache_creation: number;
}

const TIER_THRESHOLD = 200_000;
const PRICING_MAP = pricingData as Record<string, unknown>;

// Pre-compute lowercase key map for O(1) case-insensitive lookups
const LOWERCASE_KEY_MAP = new Map<string, string>();
for (const key of Object.keys(PRICING_MAP)) {
  if (!LOWERCASE_KEY_MAP.has(key.toLowerCase())) {
    LOWERCASE_KEY_MAP.set(key.toLowerCase(), key);
  }
}

function isLiteLLMPricing(entry: unknown): entry is LiteLLMPricing {
  return (
    !!entry &&
    typeof entry === 'object' &&
    'input_cost_per_token' in entry &&
    'output_cost_per_token' in entry
  );
}

function tryGetPricing(key: string): LiteLLMPricing | null {
  const entry = PRICING_MAP[key];
  return isLiteLLMPricing(entry) ? entry : null;
}

export function getPricing(modelName: string): LiteLLMPricing | null {
  const exact = tryGetPricing(modelName);
  if (exact) return exact;

  const lowerName = modelName.toLowerCase();
  const originalKey = LOWERCASE_KEY_MAP.get(lowerName);
  if (originalKey) {
    return tryGetPricing(originalKey);
  }
  return null;
}

export function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
  if (tokens <= 0) return 0;
  if (tieredRate == null || tokens <= TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  const costBelow = TIER_THRESHOLD * baseRate;
  const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
  return costBelow + costAbove;
}

export function calculateMessageCost(
  modelName: string,
  inputTokens: number,
  outputTokens: number,
  cacheReadTokens: number,
  cacheCreationTokens: number
): number {
  const pricing = getPricing(modelName);
  if (!pricing) {
    if (inputTokens > 0 || outputTokens > 0 || cacheReadTokens > 0 || cacheCreationTokens > 0) {
      console.warn(`[pricing] No pricing data for model "${modelName}", cost will be $0`);
    }
    return 0;
  }

  const inputCost = calculateTieredCost(
    inputTokens,
    pricing.input_cost_per_token,
    pricing.input_cost_per_token_above_200k_tokens
  );
  const outputCost = calculateTieredCost(
    outputTokens,
    pricing.output_cost_per_token,
    pricing.output_cost_per_token_above_200k_tokens
  );
  const cacheCreationCost = calculateTieredCost(
    cacheCreationTokens,
    pricing.cache_creation_input_token_cost ?? 0,
    pricing.cache_creation_input_token_cost_above_200k_tokens
  );
  const cacheReadCost = calculateTieredCost(
    cacheReadTokens,
    pricing.cache_read_input_token_cost ?? 0,
    pricing.cache_read_input_token_cost_above_200k_tokens
  );
  return inputCost + outputCost + cacheCreationCost + cacheReadCost;
}

export function getDisplayPricing(modelName: string): DisplayPricing | null {
  const pricing = getPricing(modelName);
  if (!pricing) return null;
  return {
    input: pricing.input_cost_per_token * 1_000_000,
    output: pricing.output_cost_per_token * 1_000_000,
    cache_read: (pricing.cache_read_input_token_cost ?? 0) * 1_000_000,
    cache_creation: (pricing.cache_creation_input_token_cost ?? 0) * 1_000_000,
};
}
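
The tier arithmetic in `calculateTieredCost` can be checked in isolation. The following is a minimal standalone sketch, not the module itself: `tieredCost` is a hypothetical local copy of that logic, and the rates are the illustrative per-token values that appear in the test comments of this diff (0.000003 base, 0.000006 above 200k).

```typescript
// Hypothetical standalone copy of the tier logic, for illustration only.
const SKETCH_TIER_THRESHOLD = 200_000;

function tieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
  if (tokens <= 0) return 0;
  // Without a tiered rate, or at/below the threshold, every token is billed at the base rate.
  if (tieredRate == null || tokens <= SKETCH_TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  // Otherwise the first 200k tokens use the base rate and only the remainder uses the tiered rate.
  return SKETCH_TIER_THRESHOLD * baseRate + (tokens - SKETCH_TIER_THRESHOLD) * tieredRate;
}

// 250k input tokens on a model that defines an above-200k rate:
// (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3
const costWithTier = tieredCost(250_000, 0.000003, 0.000006);

// The same tokens on a model with no above-200k rate: 250000 * 0.000003
const costWithoutTier = tieredCost(250_000, 0.000003);

console.log(costWithTier.toFixed(4), costWithoutTier.toFixed(4));
```

The two results (about 0.9 and 0.75) correspond to the input-cost figures that the updated tests expect for a tiered and a non-tiered model respectively.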

@ -2,44 +2,11 @@
* Tests for cost calculation in jsonl.ts
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import * as fs from 'fs';
import { describe, it, expect, vi } from 'vitest';
import { calculateMetrics } from '@main/utils/jsonl';
import type { ParsedMessage } from '@main/types';
// Mock fs module
vi.mock('fs');
describe('Cost Calculation', () => {
// Sample pricing data matching Claude models
const mockPricingData = {
'claude-3-5-sonnet-20241022': {
input_cost_per_token: 0.000003,
output_cost_per_token: 0.000015,
cache_creation_input_token_cost: 0.00000375,
cache_read_input_token_cost: 0.0000003,
input_cost_per_token_above_200k_tokens: 0.000006,
output_cost_per_token_above_200k_tokens: 0.00003,
cache_creation_input_token_cost_above_200k_tokens: 0.0000075,
cache_read_input_token_cost_above_200k_tokens: 0.0000006,
},
'claude-3-opus-20240229': {
input_cost_per_token: 0.000015,
output_cost_per_token: 0.000075,
cache_creation_input_token_cost: 0.00001875,
cache_read_input_token_cost: 0.0000015,
},
};
beforeEach(() => {
// Reset modules to clear pricing cache
vi.resetModules();
// Mock fs.readFileSync to return our test pricing data
vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify(mockPricingData));
vi.mocked(fs.existsSync).mockReturnValue(true);
});
describe('Basic Cost Calculation', () => {
it('should calculate cost for simple token usage', () => {
const messages: ParsedMessage[] = [
@ -117,6 +84,7 @@ describe('Cost Calculation', () => {
});
it('should return 0 cost when model pricing not found', () => {
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined);
const messages: ParsedMessage[] = [
{
type: 'assistant',
@ -136,6 +104,7 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
expect(metrics.costUsd).toBe(0);
warnSpy.mockRestore();
});
});
@ -166,7 +135,7 @@ describe('Cost Calculation', () => {
expect(metrics.costUsd).toBeCloseTo(1.05, 6);
});
it('should use tiered rates for input tokens above 200k threshold', () => {
it('should use base rates for input tokens above 200k when model has no tiered pricing', () => {
const messages: ParsedMessage[] = [
{
type: 'assistant',
@ -186,13 +155,14 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
// Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
// claude-3-5-sonnet-20241022 has no tiered rates in pricing.json, so base rates apply
// Input: 250000 * 0.000003 = 0.75
// Output: 1000 * 0.000015 = 0.015
// Total: 0.915
expect(metrics.costUsd).toBeCloseTo(0.915, 6);
// Total: 0.765
expect(metrics.costUsd).toBeCloseTo(0.765, 6);
});
it('should use tiered rates for output tokens above 200k threshold', () => {
it('should use base rates for output tokens above 200k when model has no tiered pricing', () => {
const messages: ParsedMessage[] = [
{
type: 'assistant',
@ -212,13 +182,14 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
// No tiered rates, so base rates for all tokens
// Input: 1000 * 0.000003 = 0.003
// Output: (200000 * 0.000015) + (50000 * 0.00003) = 3.0 + 1.5 = 4.5
// Total: 4.503
expect(metrics.costUsd).toBeCloseTo(4.503, 6);
// Output: 250000 * 0.000015 = 3.75
// Total: 3.753
expect(metrics.costUsd).toBeCloseTo(3.753, 6);
});
it('should use tiered rates for cache tokens above 200k threshold', () => {
it('should use base rates for cache tokens above 200k when model has no tiered pricing', () => {
const messages: ParsedMessage[] = [
{
type: 'assistant',
@ -240,12 +211,13 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
// No tiered rates for this model, so base rates apply
// Input: 1000 * 0.000003 = 0.003
// Output: 1000 * 0.000015 = 0.015
// Cache creation: (200000 * 0.00000375) + (50000 * 0.0000075) = 0.75 + 0.375 = 1.125
// Cache read: (200000 * 0.0000003) + (50000 * 0.0000006) = 0.06 + 0.03 = 0.09
// Total: 1.233
expect(metrics.costUsd).toBeCloseTo(1.233, 6);
// Cache creation: 250000 * 0.00000375 = 0.9375
// Cache read: 250000 * 0.0000003 = 0.075
// Total: 1.0305
expect(metrics.costUsd).toBeCloseTo(1.0305, 6);
});
it('should handle model without tiered pricing', () => {
@ -274,6 +246,34 @@ describe('Cost Calculation', () => {
// Total: 22.5
expect(metrics.costUsd).toBeCloseTo(22.5, 6);
});
it('should use tiered rates for a model that has them (claude-4-sonnet)', () => {
const messages: ParsedMessage[] = [
{
type: 'assistant',
uuid: 'msg-1',
timestamp: new Date(),
content: [],
model: 'claude-4-sonnet-20250514',
usage: {
input_tokens: 250_000,
output_tokens: 1_000,
},
toolCalls: [],
toolResults: [],
isSidechain: false,
},
];
const metrics = calculateMetrics(messages);
// claude-4-sonnet has tiered rates:
// input base=0.000003, above_200k=0.000006
// Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
// Output: 1000 * 0.000015 = 0.015
// Total: 0.915
expect(metrics.costUsd).toBeCloseTo(0.915, 6);
});
});
describe('Multiple Messages', () => {
@ -405,48 +405,6 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
expect(metrics.costUsd).toBe(0);
});
it('should handle pricing data load failure gracefully', async () => {
// Suppress expected console.error for this test
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
// Reset modules to clear the pricing cache
vi.resetModules();
// Mock fs to throw error BEFORE importing calculateMetrics
vi.mocked(fs.readFileSync).mockImplementation(() => {
throw new Error('File not found');
});
// Re-import calculateMetrics to get fresh instance with cleared cache
const { calculateMetrics: freshCalculateMetrics } = await import('@main/utils/jsonl');
const messages: ParsedMessage[] = [
{
type: 'assistant',
uuid: 'msg-1',
timestamp: new Date(),
content: [],
model: 'claude-3-5-sonnet-20241022',
usage: {
input_tokens: 1000,
output_tokens: 500,
},
toolCalls: [],
toolResults: [],
isSidechain: false,
},
];
const metrics = freshCalculateMetrics(messages);
expect(metrics.costUsd).toBe(0);
// Verify that console.error was called (error was logged)
expect(consoleErrorSpy).toHaveBeenCalled();
// Restore console.error
consoleErrorSpy.mockRestore();
});
});
describe('Model Name Lookup', () => {
@ -535,7 +493,7 @@ describe('Cost Calculation', () => {
expect(metrics.costUsd).not.toBeCloseTo(incorrectAggregatedCost, 2);
});
it('should apply tiered rates when individual messages exceed 200k', () => {
it('should use base rates when individual messages exceed 200k and model has no tiered rates', () => {
const messages: ParsedMessage[] = [
{
type: 'assistant',
@ -556,11 +514,9 @@ describe('Cost Calculation', () => {
const metrics = calculateMetrics(messages);
// Single message with 300k cache_read tokens
// First 200k: 200,000 * 0.0000003 = $0.06
// Remaining 100k: 100,000 * 0.0000006 = $0.06
// Total: $0.12
const expectedCost = 200000 * 0.0000003 + 100000 * 0.0000006;
// No tiered rates for this model, so all 300k at base rate
// 300,000 * 0.0000003 = $0.09
const expectedCost = 300000 * 0.0000003;
expect(metrics.costUsd).toBeCloseTo(expectedCost, 6);
});
});
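
The per-message tests above appear to rely on the 200k threshold being applied to each API call rather than to session totals. A minimal sketch of why that distinction matters, assuming a hypothetical session of two 150k-token calls and the illustrative rates from the test comments (`tiered` is a local stand-in, not the shared module):

```typescript
// Hypothetical stand-in for the tier logic; names and data are illustrative.
const THRESHOLD = 200_000;
const BASE_RATE = 0.000003;
const ABOVE_RATE = 0.000006;

function tiered(tokens: number, base: number, above?: number): number {
  if (tokens <= 0) return 0;
  if (above == null || tokens <= THRESHOLD) return tokens * base;
  return THRESHOLD * base + (tokens - THRESHOLD) * above;
}

// Two calls, each below the threshold on its own.
const inputTokensPerCall = [150_000, 150_000];

// Per-message costing: neither call crosses 200k, so only the base rate applies.
const perMessageTotal = inputTokensPerCall.reduce(
  (sum, tokens) => sum + tiered(tokens, BASE_RATE, ABOVE_RATE),
  0
);

// Aggregated costing: the 300k total crosses 200k, so 100k tokens hit the higher rate.
const totalTokens = inputTokensPerCall.reduce((a, b) => a + b, 0);
const aggregatedTotal = tiered(totalTokens, BASE_RATE, ABOVE_RATE);

console.log(perMessageTotal.toFixed(2), aggregatedTotal.toFixed(2));
```

Under these assumptions the per-message total is about 0.90 while the aggregated figure is about 1.20, which is the kind of discrepancy the `incorrectAggregatedCost` assertion guards against.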

@ -0,0 +1,85 @@
import { describe, it, expect, vi } from 'vitest';
import {
getPricing,
calculateTieredCost,
calculateMessageCost,
getDisplayPricing,
} from '@shared/utils/pricing';
describe('Shared Pricing Module', () => {
describe('getPricing', () => {
it('should find pricing by exact model name', () => {
const pricing = getPricing('claude-3-5-sonnet-20241022');
expect(pricing).not.toBeNull();
expect(pricing!.input_cost_per_token).toBeGreaterThan(0);
expect(pricing!.output_cost_per_token).toBeGreaterThan(0);
});
it('should find pricing case-insensitively', () => {
const pricing = getPricing('Claude-3-5-Sonnet-20241022');
expect(pricing).not.toBeNull();
});
it('should return null for unknown models', () => {
const pricing = getPricing('totally-fake-model-xyz');
expect(pricing).toBeNull();
});
});
describe('calculateTieredCost', () => {
it('should use base rate for tokens below 200k', () => {
const cost = calculateTieredCost(100_000, 0.000003);
expect(cost).toBeCloseTo(0.3, 6);
});
it('should apply tiered rate above 200k', () => {
const cost = calculateTieredCost(250_000, 0.000003, 0.000006);
expect(cost).toBeCloseTo(0.9, 6);
});
it('should use base rate when no tiered rate provided', () => {
const cost = calculateTieredCost(250_000, 0.000015);
expect(cost).toBeCloseTo(3.75, 6);
});
it('should return 0 for zero or negative tokens', () => {
expect(calculateTieredCost(0, 0.000003)).toBe(0);
expect(calculateTieredCost(-100, 0.000003)).toBe(0);
});
});
describe('calculateMessageCost', () => {
it('should compute cost for a known model', () => {
const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 0, 0);
expect(cost).toBeCloseTo(0.0105, 6);
});
it('should return 0 for unknown models', () => {
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined);
const cost = calculateMessageCost('unknown-model', 1000, 500, 0, 0);
expect(cost).toBe(0);
expect(warnSpy).toHaveBeenCalledWith(
'[pricing] No pricing data for model "unknown-model", cost will be $0'
);
warnSpy.mockRestore();
});
it('should include cache token costs', () => {
const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 300, 200);
expect(cost).toBeGreaterThan(0.0105);
});
});
describe('getDisplayPricing', () => {
it('should return per-million rates for a known model', () => {
const dp = getDisplayPricing('claude-3-5-sonnet-20241022');
expect(dp).not.toBeNull();
expect(dp!.input).toBeCloseTo(3.0, 1);
expect(dp!.output).toBeCloseTo(15.0, 1);
});
it('should return null for unknown models', () => {
expect(getDisplayPricing('unknown-model')).toBeNull();
});
});
});
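
The lookup behaviour exercised by the `getPricing` tests (exact match first, then a case-insensitive match through a precomputed lowercase map) can be sketched on its own. The table below is hypothetical and holds two made-up model names; the real module reads `resources/pricing.json`, and `lookup` is an illustrative name rather than an export of the shared module.

```typescript
type Rates = { input_cost_per_token: number; output_cost_per_token: number };

// Hypothetical table; the model names are invented for this sketch.
const table: Record<string, Rates> = {
  'Example-Model-A': { input_cost_per_token: 0.000003, output_cost_per_token: 0.000015 },
  'example-model-b': { input_cost_per_token: 0.000015, output_cost_per_token: 0.000075 },
};

// Build the lowercase -> original key map once; the first key seen wins on collisions.
const lowerToOriginal = new Map<string, string>();
for (const key of Object.keys(table)) {
  const lower = key.toLowerCase();
  if (!lowerToOriginal.has(lower)) {
    lowerToOriginal.set(lower, key);
  }
}

function lookup(name: string): Rates | null {
  // 1. Exact match on the key as written.
  if (Object.prototype.hasOwnProperty.call(table, name)) {
    return table[name];
  }
  // 2. Case-insensitive match via the precomputed map.
  const original = lowerToOriginal.get(name.toLowerCase());
  return original !== undefined ? table[original] : null;
}

console.log(lookup('EXAMPLE-MODEL-A') !== null, lookup('totally-fake-model-xyz'));
```

Precomputing the map keeps the fallback to a single hash lookup per call instead of scanning every key, which is the trade-off the comment in `pricing.ts` describes.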