Merge pull request #73 from holstein13/feat/unify-cost-calculation
Unify cost calculation with shared pricing module
commit 68f16bb717
12 changed files with 824 additions and 317 deletions
71
docs/plans/2026-02-23-unify-cost-calculation-design.md
Normal file
@@ -0,0 +1,71 @@
# Unify Cost Calculation — Design Document

**Date:** 2026-02-23
**Branch:** `feat/unify-cost-calculation`
**Related:** PR #60 (Session Analysis Report), PR #65 (Cost Calculation Metric), Issue #72 (Plan Usage Tracking)

## Problem

Cost calculation exists in two places with different pricing data and logic:

1. **Main process** (`src/main/utils/jsonl.ts`): Uses LiteLLM-sourced `pricing.json` (206 models, tiered 200k-token pricing). Populates `SessionMetrics.costUsd` for the chat UI.
2. **Renderer** (`src/renderer/utils/sessionAnalyzer.ts`): Uses a hardcoded 6-model pricing table with no tiered pricing. Generates per-model cost breakdown for the Session Report.

The two systems can produce different cost numbers for the same session and will drift further as models change.

## Solution

Create a single shared pricing module that both processes import.

### New Module: `src/shared/utils/pricing.ts`

**Exports:**

| Export | Description |
|--------|-------------|
| `ModelPricing` | Interface for per-model rates (input, output, cache read, cache creation, plus tiered variants) |
| `getPricing(modelName: string): ModelPricing \| null` | Model lookup: exact match, lowercase, case-insensitive scan |
| `calculateTieredCost(tokens, baseRate, tieredRate?): number` | Applies 200k-token tier threshold |
| `calculateMessageCost(model, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens): number` | Computes cost for a single API call |

**Pricing data:** Static `import pricingData from '../../../resources/pricing.json'` with `resolveJsonModule: true`. Replaces `fs.readFileSync` runtime loading.
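
As a rough illustration of the 200k-token tier threshold listed in the exports table, the sketch below reimplements the arithmetic as a standalone function. The name `tieredCostSketch` and the example rates are assumptions made for this illustration; they are not read from `pricing.json`.

```typescript
// Illustrative only: a standalone copy of the tier arithmetic.
// The rates below are made-up example values, not real pricing data.
const SKETCH_TIER_THRESHOLD = 200_000;

function tieredCostSketch(tokens: number, baseRate: number, tieredRate?: number): number {
  if (tokens <= 0) return 0;
  // Without a tiered rate, or at or below the threshold, every token uses the base rate.
  if (!tieredRate || tokens <= SKETCH_TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  // Above the threshold: the first 200k tokens use the base rate, the remainder the tiered rate.
  return SKETCH_TIER_THRESHOLD * baseRate + (tokens - SKETCH_TIER_THRESHOLD) * tieredRate;
}

const belowThreshold = tieredCostSketch(100_000, 0.000003, 0.000006); // about 0.3
const aboveThreshold = tieredCostSketch(250_000, 0.000003, 0.000006); // about 0.6 + 0.3 = 0.9
const noTieredRate = tieredCostSketch(250_000, 0.000003); // about 0.75, base rate throughout
```

Only the tokens beyond the threshold are billed at the higher rate, so the cost is continuous at exactly 200k tokens.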

### Consumer Changes

**`src/main/utils/jsonl.ts`:**
- Remove: `ModelPricing` interface, `loadPricingData()`, `calculateTieredCost()`, `getPricing()`, `fs`/`path` imports
- Keep: `calculateMetrics()` function
- Change: inline cost loop body → call `calculateMessageCost()` from `@shared/utils/pricing`

**`src/renderer/utils/sessionAnalyzer.ts`:**
- Remove: `MODEL_PRICING` table (~40 lines), `DEFAULT_PRICING`, local `getPricing()`, local `costUsd()`
- Change: calls at lines 476 and 900 → `calculateMessageCost()` from `@shared/utils/pricing`

**Tests:**
- `test/main/utils/costCalculation.test.ts` → update to test shared module functions
- `test/renderer/utils/sessionAnalyzer.test.ts` → mock `@shared/utils/pricing` instead of local functions
- New `test/shared/utils/pricing.test.ts` for the shared module

### Pricing JSON Import Strategy

- `pricing.json` stays in `resources/` for Electron's `extraResources` packaging
- Both Vite (renderer) and electron-vite (main) resolve the JSON import at compile time
- Remove the `fs.readFileSync` dev/prod path logic from `jsonl.ts`

### Fallback Behavior

- `getPricing()` returns `null` for unknown models
- `calculateMessageCost()` returns `0` for unknown models (matches current `jsonl.ts` behavior)
- Session analyzer callers can apply a default if needed
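
The fallback rules above could look roughly like the following at a call site. This is a minimal sketch under stated assumptions: `ExampleRates`, `lookupRates`, `costOrZero`, `costWithDefault`, and the rate values are all hypothetical stand-ins, not part of the shared module.

```typescript
// Hypothetical, simplified rate shape used only for this illustration.
interface ExampleRates {
  inputPerToken: number;
  outputPerToken: number;
}

// Made-up table standing in for pricing.json.
const KNOWN_RATES = new Map<string, ExampleRates>([
  ['example-model', { inputPerToken: 0.000003, outputPerToken: 0.000015 }],
]);

// Mirrors the "return null for unknown models" rule.
function lookupRates(model: string): ExampleRates | null {
  return KNOWN_RATES.get(model) ?? null;
}

// Mirrors the "return 0 for unknown models" rule.
function costOrZero(model: string, inputTokens: number, outputTokens: number): number {
  const rates = lookupRates(model);
  if (!rates) return 0;
  return inputTokens * rates.inputPerToken + outputTokens * rates.outputPerToken;
}

// A caller that prefers an estimate over 0 can supply its own default rates.
function costWithDefault(
  model: string,
  inputTokens: number,
  outputTokens: number,
  fallback: ExampleRates
): number {
  const rates = lookupRates(model) ?? fallback;
  return inputTokens * rates.inputPerToken + outputTokens * rates.outputPerToken;
}

const FALLBACK_RATES: ExampleRates = { inputPerToken: 0.000003, outputPerToken: 0.000015 };
const strictCost = costOrZero('unlisted-model', 1000, 500); // 0
const lenientCost = costWithDefault('unlisted-model', 1000, 500, FALLBACK_RATES); // about 0.0105
```

Keeping the default outside the lookup means a caller that wants strict behavior still sees `0` for an unknown model, while a report can opt in to an estimate.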

### What Changes for Users

- Report costs become more accurate (tiered pricing, 206 models instead of 6)
- Cost numbers between chat view and Session Report now agree exactly
- Small UI change: Visible Context header adds a "parent only · view full cost" action when available

## Out of Scope

- Plan usage tracking (see Issue #72 — pending community feedback)
- New UI surfaces for cost display
- Changes to the `costFormatting.ts` shared utility
435
docs/plans/2026-02-23-unify-cost-calculation.md
Normal file
@@ -0,0 +1,435 @@
# Unify Cost Calculation Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Replace dual cost calculation systems with a single shared pricing module used by both main and renderer processes.

**Architecture:** Create `src/shared/utils/pricing.ts` that statically imports `resources/pricing.json` and exports all pricing functions. Both `jsonl.ts` (main) and `sessionAnalyzer.ts` (renderer) consume this module instead of maintaining their own pricing logic.

**Tech Stack:** TypeScript, Vitest, electron-vite (resolveJsonModule)

---

## Tasks

### Task 1: Create the shared pricing module with tests

**Files:**
- Create: `src/shared/utils/pricing.ts`
- Test: `test/shared/utils/pricing.test.ts`

**Step 1: Write the failing tests**

Create `test/shared/utils/pricing.test.ts`:

```typescript
import { describe, it, expect } from 'vitest';
import {
  getPricing,
  calculateTieredCost,
  calculateMessageCost,
  getDisplayPricing,
} from '@shared/utils/pricing';

describe('Shared Pricing Module', () => {
  describe('getPricing', () => {
    it('should find pricing by exact model name', () => {
      // Use a model known to exist in pricing.json
      const pricing = getPricing('claude-3-5-sonnet-20241022');
      expect(pricing).not.toBeNull();
      expect(pricing!.input_cost_per_token).toBeGreaterThan(0);
      expect(pricing!.output_cost_per_token).toBeGreaterThan(0);
    });

    it('should find pricing case-insensitively', () => {
      const pricing = getPricing('Claude-3-5-Sonnet-20241022');
      expect(pricing).not.toBeNull();
    });

    it('should return null for unknown models', () => {
      const pricing = getPricing('totally-fake-model-xyz');
      expect(pricing).toBeNull();
    });
  });

  describe('calculateTieredCost', () => {
    it('should use base rate for tokens below 200k', () => {
      const cost = calculateTieredCost(100_000, 0.000003);
      expect(cost).toBeCloseTo(0.3, 6);
    });

    it('should apply tiered rate above 200k', () => {
      const cost = calculateTieredCost(250_000, 0.000003, 0.000006);
      // (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
      expect(cost).toBeCloseTo(0.9, 6);
    });

    it('should use base rate when no tiered rate provided', () => {
      const cost = calculateTieredCost(250_000, 0.000015);
      expect(cost).toBeCloseTo(3.75, 6);
    });

    it('should return 0 for zero or negative tokens', () => {
      expect(calculateTieredCost(0, 0.000003)).toBe(0);
      expect(calculateTieredCost(-100, 0.000003)).toBe(0);
    });
  });

  describe('calculateMessageCost', () => {
    it('should compute cost for a known model', () => {
      // claude-3-5-sonnet-20241022: input=0.000003, output=0.000015
      const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 0, 0);
      // (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
      expect(cost).toBeCloseTo(0.0105, 6);
    });

    it('should return 0 for unknown models', () => {
      const cost = calculateMessageCost('unknown-model', 1000, 500, 0, 0);
      expect(cost).toBe(0);
    });

    it('should include cache token costs', () => {
      const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 300, 200);
      expect(cost).toBeGreaterThan(0.0105); // more than just input+output
    });
  });

  describe('getDisplayPricing', () => {
    it('should return per-million rates for a known model', () => {
      const dp = getDisplayPricing('claude-3-5-sonnet-20241022');
      expect(dp).not.toBeNull();
      expect(dp!.input).toBeCloseTo(3.0, 1); // $3/M input
      expect(dp!.output).toBeCloseTo(15.0, 1); // $15/M output
    });

    it('should return null for unknown models', () => {
      expect(getDisplayPricing('unknown-model')).toBeNull();
    });
  });
});
```

**Step 2: Run tests to verify they fail**

Run: `pnpm vitest run test/shared/utils/pricing.test.ts`
Expected: FAIL — module does not exist

**Step 3: Create the shared pricing module**

Create `src/shared/utils/pricing.ts`:

```typescript
import pricingData from '../../../resources/pricing.json';

export interface LiteLLMPricing {
  input_cost_per_token: number;
  output_cost_per_token: number;
  cache_creation_input_token_cost?: number;
  cache_read_input_token_cost?: number;
  input_cost_per_token_above_200k_tokens?: number;
  output_cost_per_token_above_200k_tokens?: number;
  cache_creation_input_token_cost_above_200k_tokens?: number;
  cache_read_input_token_cost_above_200k_tokens?: number;
  [key: string]: unknown;
}

export interface DisplayPricing {
  input: number;
  output: number;
  cache_read: number;
  cache_creation: number;
}

const TIER_THRESHOLD = 200_000;

const pricingMap = pricingData as Record<string, unknown>;

function tryGetPricing(key: string): LiteLLMPricing | null {
  const entry = pricingMap[key];
  if (
    entry &&
    typeof entry === 'object' &&
    'input_cost_per_token' in entry &&
    'output_cost_per_token' in entry
  ) {
    return entry as LiteLLMPricing;
  }
  return null;
}

export function getPricing(modelName: string): LiteLLMPricing | null {
  const exact = tryGetPricing(modelName);
  if (exact) return exact;

  const lowerName = modelName.toLowerCase();
  const lower = tryGetPricing(lowerName);
  if (lower) return lower;

  for (const key of Object.keys(pricingMap)) {
    if (key.toLowerCase() === lowerName) {
      const match = tryGetPricing(key);
      if (match) return match;
    }
  }

  return null;
}

export function calculateTieredCost(
  tokens: number,
  baseRate: number,
  tieredRate?: number
): number {
  if (tokens <= 0) return 0;
  if (!tieredRate || tokens <= TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  const costBelow = TIER_THRESHOLD * baseRate;
  const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
  return costBelow + costAbove;
}

export function calculateMessageCost(
  modelName: string,
  inputTokens: number,
  outputTokens: number,
  cacheReadTokens: number,
  cacheCreationTokens: number
): number {
  const pricing = getPricing(modelName);
  if (!pricing) return 0;

  const inputCost = calculateTieredCost(
    inputTokens,
    pricing.input_cost_per_token,
    pricing.input_cost_per_token_above_200k_tokens
  );
  const outputCost = calculateTieredCost(
    outputTokens,
    pricing.output_cost_per_token,
    pricing.output_cost_per_token_above_200k_tokens
  );
  const cacheCreationCost = calculateTieredCost(
    cacheCreationTokens,
    pricing.cache_creation_input_token_cost ?? 0,
    pricing.cache_creation_input_token_cost_above_200k_tokens
  );
  const cacheReadCost = calculateTieredCost(
    cacheReadTokens,
    pricing.cache_read_input_token_cost ?? 0,
    pricing.cache_read_input_token_cost_above_200k_tokens
  );

  return inputCost + outputCost + cacheCreationCost + cacheReadCost;
}

export function getDisplayPricing(modelName: string): DisplayPricing | null {
  const pricing = getPricing(modelName);
  if (!pricing) return null;

  return {
    input: pricing.input_cost_per_token * 1_000_000,
    output: pricing.output_cost_per_token * 1_000_000,
    cache_read: (pricing.cache_read_input_token_cost ?? 0) * 1_000_000,
    cache_creation: (pricing.cache_creation_input_token_cost ?? 0) * 1_000_000,
  };
}
```
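
The three-step lookup used by `getPricing` (exact key, lowercased key, then a case-insensitive scan) can be tried out against a small in-memory table. The sketch below is an illustration only: `sampleTable`, `isRateEntry`, `findRates`, and every key and rate in it are hypothetical stand-ins for the real JSON data.

```typescript
type RateEntry = { input_cost_per_token: number; output_cost_per_token: number };

// Made-up table; one key is mixed-case and one entry carries no rates at all.
const sampleTable: Record<string, unknown> = {
  'example-model-a': { input_cost_per_token: 0.000003, output_cost_per_token: 0.000015 },
  'Example-Model-B': { input_cost_per_token: 0.000001, output_cost_per_token: 0.000005 },
  'metadata-entry': { note: 'no rate fields here' },
};

// Accept only object entries that carry both required rate fields.
function isRateEntry(value: unknown): value is RateEntry {
  return (
    typeof value === 'object' &&
    value !== null &&
    'input_cost_per_token' in value &&
    'output_cost_per_token' in value
  );
}

function findRates(name: string): RateEntry | null {
  // 1. Exact key.
  const exact = sampleTable[name];
  if (isRateEntry(exact)) return exact;

  // 2. Lowercased key.
  const lowerName = name.toLowerCase();
  const lowered = sampleTable[lowerName];
  if (isRateEntry(lowered)) return lowered;

  // 3. Scan every key, comparing case-insensitively.
  for (const key of Object.keys(sampleTable)) {
    const candidate = sampleTable[key];
    if (key.toLowerCase() === lowerName && isRateEntry(candidate)) return candidate;
  }
  return null;
}

const viaExact = findRates('example-model-a'); // step 1
const viaLowercase = findRates('EXAMPLE-MODEL-A'); // step 2
const viaScan = findRates('example-model-b'); // step 3, the stored key is mixed-case
const rejected = findRates('metadata-entry'); // null, entry lacks rate fields
```

The scan in step 3 only matters when the table itself contains keys that are not all lowercase; otherwise steps 1 and 2 already cover every match.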

**Step 4: Run tests to verify they pass**

Run: `pnpm vitest run test/shared/utils/pricing.test.ts`
Expected: PASS

**Step 5: Commit**

```bash
git add src/shared/utils/pricing.ts test/shared/utils/pricing.test.ts
git commit -m "feat: add shared pricing module with LiteLLM data"
```

---

### Task 2: Wire jsonl.ts to use the shared pricing module

**Files:**
- Modify: `src/main/utils/jsonl.ts:219-400` (remove pricing functions, update calculateMetrics)
- Test: `test/main/utils/costCalculation.test.ts` (update to remove fs mocking)

**Step 1: Update the cost calculation test to use static imports**

The existing tests mock `fs.readFileSync` to provide pricing data. Since the shared module uses a static JSON import, the tests should instead test against real pricing data or mock the shared module.

Update `test/main/utils/costCalculation.test.ts`:
- Remove `import * as fs from 'fs'` and `vi.mock('fs')`
- Remove `mockPricingData` and the `beforeEach` that mocks `fs.readFileSync`
- Update model names in tests to match models that exist in `resources/pricing.json` (the existing `claude-3-5-sonnet-20241022` and `claude-3-opus-20240229` should already be there)
- Update expected cost values to match the actual rates from `pricing.json` (verify they match the existing mock data — they should be identical since the mock was based on real rates)
- Remove the "pricing data load failure" test (lines 409-449), since there is no runtime file loading left to fail
- Keep all other test cases and assertions as-is

**Step 2: Run updated tests to verify they fail**

Run: `pnpm vitest run test/main/utils/costCalculation.test.ts`
Expected: FAIL — jsonl.ts still has old imports

**Step 3: Update jsonl.ts**

In `src/main/utils/jsonl.ts`:
- Remove lines 219-320: the `fs`/`path` imports, `ModelPricing` interface, `TIER_THRESHOLD`, `pricingCache`, `loadPricingData()`, `calculateTieredCost()`, `getPricing()`
- Add import at top of file: `import { calculateMessageCost } from '@shared/utils/pricing';`
- In `calculateMetrics()` (lines 354-400), replace the inline cost calculation block (lines 374-398) with:

```typescript
if (msg.model) {
  costUsd += calculateMessageCost(
    msg.model,
    msgInputTokens,
    msgOutputTokens,
    msgCacheReadTokens,
    msgCacheCreationTokens
  );
}
```

- Remove the unused `modelName` variable (line 338) and the block that sets it (lines 370-372)
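
To illustrate why the cost is accumulated per message, using each message's own model, here is a minimal sketch of a session total. Everything in it is an assumption for illustration: `UsageMessage` is a simplified stand-in for the real message type, `exampleMessageCost` replaces `calculateMessageCost`, and the rates are invented.

```typescript
// Hypothetical, simplified message shape; the real parsed message has more fields.
interface UsageMessage {
  model?: string;
  inputTokens: number;
  outputTokens: number;
}

// Made-up per-token rates standing in for pricing.json.
const EXAMPLE_RATES = new Map<string, { input: number; output: number }>([
  ['example-large', { input: 0.000015, output: 0.000075 }],
  ['example-small', { input: 0.000003, output: 0.000015 }],
]);

// Stand-in for the shared per-message cost function: unknown models cost 0.
function exampleMessageCost(model: string, inputTokens: number, outputTokens: number): number {
  const rates = EXAMPLE_RATES.get(model);
  if (!rates) return 0;
  return inputTokens * rates.input + outputTokens * rates.output;
}

// Each message is priced with its own model, so sessions that switch models sum correctly.
function sumSessionCost(messages: UsageMessage[]): number {
  let costUsd = 0;
  for (const msg of messages) {
    if (msg.model) {
      costUsd += exampleMessageCost(msg.model, msg.inputTokens, msg.outputTokens);
    }
  }
  return costUsd;
}

const mixedSession: UsageMessage[] = [
  { model: 'example-large', inputTokens: 1000, outputTokens: 100 }, // 0.015 + 0.0075
  { model: 'example-small', inputTokens: 2000, outputTokens: 200 }, // 0.006 + 0.003
  { model: 'not-in-table', inputTokens: 5000, outputTokens: 500 }, // contributes 0
  { inputTokens: 300, outputTokens: 30 }, // no model, skipped
];
const mixedTotal = sumSessionCost(mixedSession); // about 0.0315
```

Because no session-wide model name is tracked in this loop, a variable like the removed `modelName` has no role in the calculation.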

**Step 4: Run tests to verify they pass**

Run: `pnpm vitest run test/main/utils/costCalculation.test.ts`
Expected: PASS

**Step 5: Run full test suite to check for regressions**

Run: `pnpm test`
Expected: All tests pass

**Step 6: Commit**

```bash
git add src/main/utils/jsonl.ts test/main/utils/costCalculation.test.ts
git commit -m "refactor: wire jsonl.ts to shared pricing module"
```

---

### Task 3: Wire sessionAnalyzer.ts and CostSection.tsx to use the shared pricing module

**Files:**
- Modify: `src/renderer/utils/sessionAnalyzer.ts:32,60-130` (remove local pricing)
- Modify: `src/renderer/types/sessionReport.ts:23-28` (update ModelPricing type)
- Modify: `src/renderer/components/report/sections/CostSection.tsx:3,10,39-46,204` (update imports and usage)

**Step 1: Run existing session analyzer tests as baseline**

Run: `pnpm vitest run test/renderer/utils/sessionAnalyzer.test.ts`
Expected: PASS (baseline)

**Step 2: Update sessionAnalyzer.ts**

In `src/renderer/utils/sessionAnalyzer.ts`:
- Remove the `ModelPricing` import from `@renderer/types/sessionReport` (line 32)
- Remove lines 60-130: `MODEL_PRICING` table, `DEFAULT_PRICING`, `getPricing()`, `costUsd()`
- Add import: `import { calculateMessageCost, getDisplayPricing } from '@shared/utils/pricing';`
- Export `getDisplayPricing` as `getPricing` for backward compat with CostSection: `export { getDisplayPricing as getPricing } from '@shared/utils/pricing';`
- Replace `costUsd(model, inpTok, outTok, cr, cc)` at line 476 with `calculateMessageCost(model, inpTok, outTok, cr, cc)`
- Replace `costUsd(subagentModel, ...)` at line 900 with `calculateMessageCost(subagentModel, proc.metrics.inputTokens, proc.metrics.outputTokens, proc.metrics.cacheReadTokens, proc.metrics.cacheCreationTokens)`

**Step 3: Update sessionReport.ts ModelPricing type**

In `src/renderer/types/sessionReport.ts`:
- Replace the existing `ModelPricing` interface (lines 23-28) with a re-export from the shared module:

```typescript
export type { DisplayPricing as ModelPricing } from '@shared/utils/pricing';
```

This keeps backward compatibility — `CostSection.tsx` imports `ModelPricing` from here and expects `{ input, output, cache_read, cache_creation }` which matches `DisplayPricing`.
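
A small sketch of why the alias is compatible: a type alias introduces no new shape, so code written against `ModelPricing` accepts a `DisplayPricing` value unchanged. The local alias below stands in for the re-export, and `PerTokenRates`, `toDisplayRates`, `formatRateRow`, and the rates are hypothetical names and values chosen for this illustration.

```typescript
interface DisplayPricing {
  input: number;
  output: number;
  cache_read: number;
  cache_creation: number;
}

// Local alias standing in for `export type { DisplayPricing as ModelPricing }`.
type ModelPricing = DisplayPricing;

// Hypothetical per-token input shape with optional cache rates.
interface PerTokenRates {
  input_cost_per_token: number;
  output_cost_per_token: number;
  cache_read_input_token_cost?: number;
  cache_creation_input_token_cost?: number;
}

// Convert per-token rates to per-million-token display rates; missing cache rates become 0.
function toDisplayRates(rates: PerTokenRates): DisplayPricing {
  return {
    input: rates.input_cost_per_token * 1_000_000,
    output: rates.output_cost_per_token * 1_000_000,
    cache_read: (rates.cache_read_input_token_cost ?? 0) * 1_000_000,
    cache_creation: (rates.cache_creation_input_token_cost ?? 0) * 1_000_000,
  };
}

// A consumer typed against the old name keeps compiling against the new shape.
function formatRateRow(pricing: ModelPricing): string {
  return `in $${pricing.input.toFixed(2)}/M, out $${pricing.output.toFixed(2)}/M`;
}

const displayRates = toDisplayRates({
  input_cost_per_token: 0.000003, // made-up example rate
  output_cost_per_token: 0.000015, // made-up example rate
});
const rateRow = formatRateRow(displayRates); // "in $3.00/M, out $15.00/M"
```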

**Step 4: Update CostSection.tsx**

In `src/renderer/components/report/sections/CostSection.tsx`:
- Line 3: `import { getPricing } from '@renderer/utils/sessionAnalyzer'` needs no change, because `sessionAnalyzer.ts` re-exports `getDisplayPricing` as `getPricing`. Verify the import still resolves.
- The `ModelPricing` import from `@renderer/types/sessionReport` (line 10) continues to work via the re-export.
- The `CostBreakdownCard` (lines 34-46) uses `pricing.input`, `pricing.output`, etc. as per-million rates, which matches `DisplayPricing` from `getDisplayPricing()`.

**Step 5: Run session analyzer tests**

Run: `pnpm vitest run test/renderer/utils/sessionAnalyzer.test.ts`
Expected: PASS

**Step 6: Run full test suite**

Run: `pnpm test`
Expected: All tests pass

**Step 7: Commit**

```bash
git add src/renderer/utils/sessionAnalyzer.ts src/renderer/types/sessionReport.ts src/renderer/components/report/sections/CostSection.tsx
git commit -m "refactor: wire session analyzer and CostSection to shared pricing"
```

---

### Task 4: Typecheck, lint, and verify the app runs

**Files:**
- No new files — verification only

**Step 1: Run typecheck**

Run: `pnpm typecheck`
Expected: No errors

**Step 2: Run linter**

Run: `pnpm lint:fix`
Expected: Clean or auto-fixed

**Step 3: Run full test suite**

Run: `pnpm test`
Expected: All tests pass

**Step 4: Run the app and verify cost display**

Run: `pnpm dev`
- Open a session with known token usage
- Verify `TokenUsageDisplay` shows cost in the chat view
- Open the Session Report tab and verify cost-by-model breakdown renders
- Verify CostBreakdownCard expands with per-token-type rates
- Confirm chat view cost and report cost show the same total

**Step 5: Commit any fixes from verification**

If any fixes were needed, commit them:
```bash
git add -A
git commit -m "fix: address typecheck/lint issues from cost unification"
```

---

### Task 5: Clean up dead code from package.json extraResources

**Files:**
- Modify: `package.json` (optional — evaluate if `extraResources` for pricing.json is still needed)

**Step 1: Check if pricing.json is still loaded at runtime anywhere**

Search for any remaining `fs.readFileSync` or runtime references to `pricing.json`:

Run: `grep -r "pricing.json" src/`
Expected: Only the static import in `src/shared/utils/pricing.ts`

**Step 2: Evaluate extraResources**

If no runtime file loading remains, the `extraResources` entry for `pricing.json` in `package.json` is dead config. Removing it is low-priority, though: leaving it in place only means the file is copied into the app bundle without being used. Leave it for now unless it causes issues, and document the decision.

**Step 3: Final commit**

```bash
git add docs/plans/2026-02-23-unify-cost-calculation.md
git commit -m "docs: finalize implementation plan for cost unification"
```
@@ -9,6 +9,7 @@

import { isCommandOutputContent, sanitizeDisplayContent } from '@shared/utils/contentSanitizer';
import { createLogger } from '@shared/utils/logger';
import { calculateMessageCost } from '@shared/utils/pricing';
import * as readline from 'readline';

import { LocalFileSystemProvider } from '../services/infrastructure/LocalFileSystemProvider';
@@ -212,113 +213,6 @@ function parseMessageType(type?: string): MessageType | null {
  }
}

// =============================================================================
// Cost Calculation
// =============================================================================

import * as fs from 'fs';
import * as path from 'path';

interface ModelPricing {
  input_cost_per_token: number;
  output_cost_per_token: number;
  cache_creation_input_token_cost?: number;
  cache_read_input_token_cost?: number;
  input_cost_per_token_above_200k_tokens?: number;
  output_cost_per_token_above_200k_tokens?: number;
  cache_creation_input_token_cost_above_200k_tokens?: number;
  cache_read_input_token_cost_above_200k_tokens?: number;
  [key: string]: unknown;
}

const TIER_THRESHOLD = 200_000;

// Cache pricing data in memory (loaded once on first use)
let pricingCache: Record<string, unknown> | null = null;

/**
 * Load pricing data from resources directory.
 * Uses electron-vite resource directory pattern:
 * - Development: resources/pricing.json (project root)
 * - Production: process.resourcesPath/pricing.json
 */
function loadPricingData(): Record<string, unknown> {
  if (pricingCache !== null) {
    return pricingCache;
  }

  try {
    // Determine if we're in development or production
    const isDev = process.env.NODE_ENV === 'development' || !process.resourcesPath;

    let pricingPath: string;
    if (isDev) {
      // Development: Compiled code is in dist-electron/main/
      // __dirname = /path/to/project/dist-electron/main
      // Need to go up 2 levels to reach project root, then into resources/
      pricingPath = path.join(__dirname, '..', '..', 'resources', 'pricing.json');
    } else {
      // Production: pricing.json in app's resources directory
      pricingPath = path.join(process.resourcesPath, 'pricing.json');
    }

    const data = fs.readFileSync(pricingPath, 'utf-8');
    pricingCache = JSON.parse(data) as Record<string, unknown>;
    return pricingCache;
  } catch (error) {
    console.error('Failed to load pricing data:', error);
    // Return empty object if pricing data can't be loaded
    pricingCache = {};
    return pricingCache;
  }
}

function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
  if (tokens <= 0) return 0;
  if (!tieredRate || tokens <= TIER_THRESHOLD) {
    return tokens * baseRate;
  }
  const costBelow = TIER_THRESHOLD * baseRate;
  const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
  return costBelow + costAbove;
}

function getPricing(modelName: string): ModelPricing | null {
  const pricing = loadPricingData();

  const tryGet = (key: string): ModelPricing | null => {
    const entry = pricing[key];
    if (
      entry &&
      typeof entry === 'object' &&
      'input_cost_per_token' in entry &&
      'output_cost_per_token' in entry
    ) {
      return entry as ModelPricing;
    }
    return null;
  };

  // Try exact match
  const exact = tryGet(modelName);
  if (exact) return exact;

  // Try lowercase
  const lowerName = modelName.toLowerCase();
  const lower = tryGet(lowerName);
  if (lower) return lower;

  // Try case-insensitive search
  for (const key of Object.keys(pricing)) {
    if (key.toLowerCase() === lowerName) {
      const match = tryGet(key);
      if (match) return match;
    }
  }

  return null;
}

// =============================================================================
// Metrics Calculation
// =============================================================================
@@ -335,7 +229,6 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics {
  let outputTokens = 0;
  let cacheReadTokens = 0;
  let cacheCreationTokens = 0;
  let modelName: string | undefined;

  // Get timestamps for duration (loop instead of Math.min/max spread to avoid stack overflow on large sessions)
  const timestamps = messages.map((m) => m.timestamp.getTime()).filter((t) => !isNaN(t));
@@ -366,36 +259,14 @@ export function calculateMetrics(messages: ParsedMessage[]): SessionMetrics {
      cacheReadTokens += msgCacheReadTokens;
      cacheCreationTokens += msgCacheCreationTokens;

      // Calculate cost for this message if we have pricing data
      if (msg.model && !modelName) {
        modelName = msg.model;
      }

      if (msg.model) {
        const pricing = getPricing(msg.model);
        if (pricing) {
          const inputCost = calculateTieredCost(
            msgInputTokens,
            pricing.input_cost_per_token,
            pricing.input_cost_per_token_above_200k_tokens
          );
          const outputCost = calculateTieredCost(
            msgOutputTokens,
            pricing.output_cost_per_token,
            pricing.output_cost_per_token_above_200k_tokens
          );
          const cacheCreationCost = calculateTieredCost(
            msgCacheCreationTokens,
            pricing.cache_creation_input_token_cost ?? 0,
            pricing.cache_creation_input_token_cost_above_200k_tokens
          );
          const cacheReadCost = calculateTieredCost(
            msgCacheReadTokens,
            pricing.cache_read_input_token_cost ?? 0,
            pricing.cache_read_input_token_cost_above_200k_tokens
          );
          costUsd += inputCost + outputCost + cacheCreationCost + cacheReadCost;
        }
        costUsd += calculateMessageCost(
          msg.model,
          msgInputTokens,
          msgOutputTokens,
          msgCacheReadTokens,
          msgCacheCreationTokens
        );
      }
    }
  }

@@ -64,6 +64,7 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
    syncSearchMatchesWithRendered,
    selectSearchMatch,
    setTabVisibleAIGroup,
    openSessionReport,
  } = useStore(
    useShallow((s) => ({
      searchQuery: s.searchQuery,
@@ -76,6 +77,7 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
      syncSearchMatchesWithRendered: s.syncSearchMatchesWithRendered,
      selectSearchMatch: s.selectSearchMatch,
      setTabVisibleAIGroup: s.setTabVisibleAIGroup,
      openSessionReport: s.openSessionReport,
    }))
  );

@@ -100,6 +102,14 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
    sessionDetail,
  } = tabData;

  // Compute combined subagent cost from process metrics
  const subagentCostUsd = useMemo(() => {
    const processes = sessionDetail?.processes;
    if (!processes || processes.length === 0) return undefined;
    const total = processes.reduce((sum, p) => sum + (p.metrics.costUsd ?? 0), 0);
    return total > 0 ? total : undefined;
  }, [sessionDetail?.processes]);

  // State for Context button hover (local state OK - doesn't need per-tab isolation)
  const [isContextButtonHovered, setIsContextButtonHovered] = useState(false);

@@ -872,6 +882,8 @@ export const ChatHistory = ({ tabId }: ChatHistoryProps): JSX.Element => {
  onNavigateToUserGroup={handleNavigateToUserGroup}
  totalSessionTokens={lastAiGroupTotalTokens}
  sessionMetrics={sessionDetail?.metrics}
  subagentCostUsd={subagentCostUsd}
  onViewReport={effectiveTabId ? () => openSessionReport(effectiveTabId) : undefined}
  phaseInfo={sessionPhaseInfo ?? undefined}
  selectedPhase={selectedContextPhase}
  onPhaseChange={setSelectedContextPhase}

@@ -28,7 +28,9 @@ interface SessionContextHeaderProps {
  totalTokens: number;
  totalSessionTokens?: number;
  sessionMetrics?: SessionMetrics;
  subagentCostUsd?: number;
  onClose?: () => void;
  onViewReport?: () => void;
  phaseInfo?: ContextPhaseInfo;
  selectedPhase: number | null;
  onPhaseChange: (phase: number | null) => void;
@@ -41,7 +43,9 @@ export const SessionContextHeader = ({
  totalTokens,
  totalSessionTokens,
  sessionMetrics,
  subagentCostUsd,
  onClose,
  onViewReport,
  phaseInfo,
  selectedPhase,
  onPhaseChange,
@@ -130,8 +134,30 @@ export const SessionContextHeader = ({
    <div className="col-span-2">
      <span style={{ color: COLOR_TEXT_MUTED }}>Session Cost: </span>
      <span className="font-medium tabular-nums" style={{ color: COLOR_TEXT_SECONDARY }}>
        {formatCostUsd(sessionMetrics.costUsd)}
        {formatCostUsd(sessionMetrics.costUsd + (subagentCostUsd ?? 0))}
      </span>
      {subagentCostUsd !== undefined && subagentCostUsd > 0 && (
        <span style={{ color: COLOR_TEXT_MUTED }}>
          {' ('}
          {formatCostUsd(sessionMetrics.costUsd)}
          {' parent + '}
          {formatCostUsd(subagentCostUsd)}
          {' subagents'}
          {onViewReport && (
            <>
              {' · '}
              <button
                onClick={onViewReport}
                className="underline"
                style={{ color: COLOR_TEXT_SECONDARY }}
              >
                details
              </button>
            </>
          )}
          {')'}
        </span>
      )}
    </div>
  )}
</div>

@@ -49,6 +49,8 @@ export const SessionContextPanel = ({
  onNavigateToUserGroup,
  totalSessionTokens,
  sessionMetrics,
  subagentCostUsd,
  onViewReport,
  phaseInfo,
  selectedPhase,
  onPhaseChange,
@@ -192,7 +194,9 @@ export const SessionContextPanel = ({
  totalTokens={totalTokens}
  totalSessionTokens={totalSessionTokens}
  sessionMetrics={sessionMetrics}
  subagentCostUsd={subagentCostUsd}
  onClose={onClose}
  onViewReport={onViewReport}
  phaseInfo={phaseInfo}
  selectedPhase={selectedPhase}
  onPhaseChange={onPhaseChange}

@ -27,6 +27,10 @@ export interface SessionContextPanelProps {
|
|||
totalSessionTokens?: number;
|
||||
/** Full session metrics (input, output, cache tokens, cost) */
|
||||
sessionMetrics?: SessionMetrics;
|
||||
/** Combined cost of all subagent processes */
|
||||
subagentCostUsd?: number;
|
||||
/** Open the Session Report to see full cost breakdown */
|
||||
onViewReport?: () => void;
|
||||
/** Phase information for phase selector */
|
||||
phaseInfo?: ContextPhaseInfo;
|
||||
/** Currently selected phase (null = current/latest) */
|
||||
|
|
|
|||
|
|
@@ -20,12 +20,7 @@ import type {
 // Pricing
 // =============================================================================

-export interface ModelPricing {
-input: number;
-output: number;
-cache_read: number;
-cache_creation: number;
-}
+export type { DisplayPricing as ModelPricing } from '@shared/utils/pricing';

 // =============================================================================
 // Report Sections
|
|||
|
|
@ -22,6 +22,7 @@ import {
|
|||
detectModelMismatch,
|
||||
detectSwitchPattern,
|
||||
} from '@renderer/utils/reportAssessments';
|
||||
import { calculateMessageCost } from '@shared/utils/pricing';
|
||||
|
||||
import type {
|
||||
AgentTreeNode,
|
||||
|
|
@@ -29,7 +30,6 @@ import type {
 GitCommit,
 IdleGap,
 KeyEvent,
-ModelPricing,
 ModelSwitch,
 ModelTokenStats,
 OutOfScopeFindings,
@@ -53,81 +53,8 @@ import type {
 ToolCall,
 } from '@shared/types';

-// =============================================================================
-// Pricing Table (USD per 1M tokens)
-// =============================================================================
-
-const MODEL_PRICING: Record<string, ModelPricing> = {
-'opus-4': {
-input: 15.0,
-output: 75.0,
-cache_read: 1.5,
-cache_creation: 18.75,
-},
-'sonnet-4': {
-input: 3.0,
-output: 15.0,
-cache_read: 0.3,
-cache_creation: 3.75,
-},
-'haiku-4': {
-input: 0.8,
-output: 4.0,
-cache_read: 0.08,
-cache_creation: 1.0,
-},
-'opus-3': {
-input: 15.0,
-output: 75.0,
-cache_read: 1.5,
-cache_creation: 18.75,
-},
-'sonnet-3': {
-input: 3.0,
-output: 15.0,
-cache_read: 0.3,
-cache_creation: 3.75,
-},
-'haiku-3': {
-input: 0.25,
-output: 1.25,
-cache_read: 0.03,
-cache_creation: 0.3,
-},
-};
-
-const DEFAULT_PRICING: ModelPricing = {
-input: 3.0,
-output: 15.0,
-cache_read: 0.3,
-cache_creation: 3.75,
-};
-
-export function getPricing(modelName: string): ModelPricing {
-const nameTokens: string[] = modelName.toLowerCase().match(/[a-z0-9]+/g) ?? [];
-for (const [key, pricing] of Object.entries(MODEL_PRICING)) {
-const keyTokens: string[] = key.match(/[a-z0-9]+/g) ?? [];
-if (keyTokens.every((t) => nameTokens.includes(t))) return pricing;
-}
-return DEFAULT_PRICING;
-}
-
-function costUsd(
-modelName: string,
-inputTok: number,
-outputTok: number,
-cacheReadTok: number,
-cacheCreationTok: number
-): number {
-const p = getPricing(modelName);
-return (
-(inputTok * p.input +
-outputTok * p.output +
-cacheReadTok * p.cache_read +
-cacheCreationTok * p.cache_creation) /
-1_000_000
-);
-}
+// Re-export getDisplayPricing as getPricing for backward compat with CostSection
+export { getDisplayPricing as getPricing } from '@shared/utils/pricing';

 // =============================================================================
 // Helpers
@@ -473,7 +400,7 @@ export function analyzeSession(detail: SessionDetail): SessionReport {
 stats.cacheCreation += cc;
 stats.cacheRead += cr;

-const callCost = costUsd(model, inpTok, outTok, cr, cc);
+const callCost = calculateMessageCost(model, inpTok, outTok, cr, cc);
 stats.costUsd += callCost;
 parentCost += callCost;

@@ -897,7 +824,7 @@ export function analyzeSession(detail: SessionDetail): SessionReport {
 proc.messages.find((m: ParsedMessage) => m.type === 'assistant' && m.model)?.model ??
 'default (inherits parent)';
 // Compute cost from subagent token breakdown (proc.metrics.costUsd is not populated upstream)
-const computedCost = costUsd(
+const computedCost = calculateMessageCost(
 subagentModel,
 proc.metrics.inputTokens,
 proc.metrics.outputTokens,
|
|||
121
src/shared/utils/pricing.ts
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// eslint-disable-next-line no-restricted-imports -- resources/ is outside src/, no alias available
|
||||
import pricingData from '../../../resources/pricing.json';
|
||||
|
||||
export interface LiteLLMPricing {
|
||||
input_cost_per_token: number;
|
||||
output_cost_per_token: number;
|
||||
cache_creation_input_token_cost?: number;
|
||||
cache_read_input_token_cost?: number;
|
||||
input_cost_per_token_above_200k_tokens?: number;
|
||||
output_cost_per_token_above_200k_tokens?: number;
|
||||
cache_creation_input_token_cost_above_200k_tokens?: number;
|
||||
cache_read_input_token_cost_above_200k_tokens?: number;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface DisplayPricing {
|
||||
input: number;
|
||||
output: number;
|
||||
cache_read: number;
|
||||
cache_creation: number;
|
||||
}
|
||||
|
||||
const TIER_THRESHOLD = 200_000;
|
||||
|
||||
const PRICING_MAP = pricingData as Record<string, unknown>;
|
||||
|
||||
// Pre-compute lowercase key map for O(1) case-insensitive lookups
|
||||
const LOWERCASE_KEY_MAP = new Map<string, string>();
|
||||
for (const key of Object.keys(PRICING_MAP)) {
|
||||
if (!LOWERCASE_KEY_MAP.has(key.toLowerCase())) {
|
||||
LOWERCASE_KEY_MAP.set(key.toLowerCase(), key);
|
||||
}
|
||||
}
|
||||
|
||||
function isLiteLLMPricing(entry: unknown): entry is LiteLLMPricing {
|
||||
return (
|
||||
!!entry &&
|
||||
typeof entry === 'object' &&
|
||||
'input_cost_per_token' in entry &&
|
||||
'output_cost_per_token' in entry
|
||||
);
|
||||
}
|
||||
|
||||
function tryGetPricing(key: string): LiteLLMPricing | null {
|
||||
const entry = PRICING_MAP[key];
|
||||
return isLiteLLMPricing(entry) ? entry : null;
|
||||
}
|
||||
|
||||
export function getPricing(modelName: string): LiteLLMPricing | null {
|
||||
const exact = tryGetPricing(modelName);
|
||||
if (exact) return exact;
|
||||
|
||||
const lowerName = modelName.toLowerCase();
|
||||
const originalKey = LOWERCASE_KEY_MAP.get(lowerName);
|
||||
if (originalKey) {
|
||||
return tryGetPricing(originalKey);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
export function calculateTieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
|
||||
if (tokens <= 0) return 0;
|
||||
if (tieredRate == null || tokens <= TIER_THRESHOLD) {
|
||||
return tokens * baseRate;
|
||||
}
|
||||
const costBelow = TIER_THRESHOLD * baseRate;
|
||||
const costAbove = (tokens - TIER_THRESHOLD) * tieredRate;
|
||||
return costBelow + costAbove;
|
||||
}
|
||||
|
||||
export function calculateMessageCost(
|
||||
modelName: string,
|
||||
inputTokens: number,
|
||||
outputTokens: number,
|
||||
cacheReadTokens: number,
|
||||
cacheCreationTokens: number
|
||||
): number {
|
||||
const pricing = getPricing(modelName);
|
||||
if (!pricing) {
|
||||
if (inputTokens > 0 || outputTokens > 0 || cacheReadTokens > 0 || cacheCreationTokens > 0) {
|
||||
console.warn(`[pricing] No pricing data for model "${modelName}", cost will be $0`);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
const inputCost = calculateTieredCost(
|
||||
inputTokens,
|
||||
pricing.input_cost_per_token,
|
||||
pricing.input_cost_per_token_above_200k_tokens
|
||||
);
|
||||
const outputCost = calculateTieredCost(
|
||||
outputTokens,
|
||||
pricing.output_cost_per_token,
|
||||
pricing.output_cost_per_token_above_200k_tokens
|
||||
);
|
||||
const cacheCreationCost = calculateTieredCost(
|
||||
cacheCreationTokens,
|
||||
pricing.cache_creation_input_token_cost ?? 0,
|
||||
pricing.cache_creation_input_token_cost_above_200k_tokens
|
||||
);
|
||||
const cacheReadCost = calculateTieredCost(
|
||||
cacheReadTokens,
|
||||
pricing.cache_read_input_token_cost ?? 0,
|
||||
pricing.cache_read_input_token_cost_above_200k_tokens
|
||||
);
|
||||
|
||||
return inputCost + outputCost + cacheCreationCost + cacheReadCost;
|
||||
}
|
||||
|
||||
export function getDisplayPricing(modelName: string): DisplayPricing | null {
|
||||
const pricing = getPricing(modelName);
|
||||
if (!pricing) return null;
|
||||
|
||||
return {
|
||||
input: pricing.input_cost_per_token * 1_000_000,
|
||||
output: pricing.output_cost_per_token * 1_000_000,
|
||||
cache_read: (pricing.cache_read_input_token_cost ?? 0) * 1_000_000,
|
||||
cache_creation: (pricing.cache_creation_input_token_cost ?? 0) * 1_000_000,
|
||||
};
|
||||
}
|
||||
|
|
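As a worked illustration of the 200k-token tier logic in `calculateTieredCost`, the sketch below reimplements the same arithmetic in isolation. The function name `tieredCost` and the per-token rates are assumptions chosen for illustration (they mirror the numbers used in the tests in this commit) and are not read from `pricing.json`.

```typescript
// Minimal sketch of tiered pricing: tokens up to the threshold are billed at
// the base rate, and only the tokens beyond it are billed at the tiered rate.
const TIER_THRESHOLD = 200_000;

function tieredCost(tokens: number, baseRate: number, tieredRate?: number): number {
  if (tokens <= 0) return 0;
  // Without a tiered rate, or at or below the threshold, everything uses the base rate.
  if (tieredRate == null || tokens <= TIER_THRESHOLD) return tokens * baseRate;
  return TIER_THRESHOLD * baseRate + (tokens - TIER_THRESHOLD) * tieredRate;
}

// Hypothetical rates (USD per token), assumed for this example only.
const inputCost = tieredCost(250_000, 0.000003, 0.000006); // 200k base + 50k tiered
const outputCost = tieredCost(1_000, 0.000015, 0.00003); // below threshold, base rate only
const noTierCost = tieredCost(250_000, 0.000003); // no tiered rate, base rate for all tokens
const total = inputCost + outputCost;

console.log(inputCost.toFixed(3), outputCost.toFixed(3), noTierCost.toFixed(3), total.toFixed(3));
```

With these assumed rates the input side works out to 0.6 + 0.3 = 0.9 and the total to 0.915, whereas the same 250k input tokens cost 0.75 when no tiered rate is supplied. That difference is why the updated expectations for models without tiered rates are lower than the old ones.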
@@ -2,44 +2,11 @@
 * Tests for cost calculation in jsonl.ts
 */

-import { describe, it, expect, beforeEach, vi } from 'vitest';
-import * as fs from 'fs';
+import { describe, it, expect, vi } from 'vitest';
 import { calculateMetrics } from '@main/utils/jsonl';
 import type { ParsedMessage } from '@main/types';

-// Mock fs module
-vi.mock('fs');
-
 describe('Cost Calculation', () => {
-// Sample pricing data matching Claude models
-const mockPricingData = {
-'claude-3-5-sonnet-20241022': {
-input_cost_per_token: 0.000003,
-output_cost_per_token: 0.000015,
-cache_creation_input_token_cost: 0.00000375,
-cache_read_input_token_cost: 0.0000003,
-input_cost_per_token_above_200k_tokens: 0.000006,
-output_cost_per_token_above_200k_tokens: 0.00003,
-cache_creation_input_token_cost_above_200k_tokens: 0.0000075,
-cache_read_input_token_cost_above_200k_tokens: 0.0000006,
-},
-'claude-3-opus-20240229': {
-input_cost_per_token: 0.000015,
-output_cost_per_token: 0.000075,
-cache_creation_input_token_cost: 0.00001875,
-cache_read_input_token_cost: 0.0000015,
-},
-};
-
-beforeEach(() => {
-// Reset modules to clear pricing cache
-vi.resetModules();
-
-// Mock fs.readFileSync to return our test pricing data
-vi.mocked(fs.readFileSync).mockReturnValue(JSON.stringify(mockPricingData));
-vi.mocked(fs.existsSync).mockReturnValue(true);
-});
-
 describe('Basic Cost Calculation', () => {
 it('should calculate cost for simple token usage', () => {
 const messages: ParsedMessage[] = [
@ -117,6 +84,7 @@ describe('Cost Calculation', () => {
|
|||
});
|
||||
|
||||
it('should return 0 cost when model pricing not found', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined);
|
||||
const messages: ParsedMessage[] = [
|
||||
{
|
||||
type: 'assistant',
|
||||
|
|
@ -136,6 +104,7 @@ describe('Cost Calculation', () => {
|
|||
|
||||
const metrics = calculateMetrics(messages);
|
||||
expect(metrics.costUsd).toBe(0);
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@@ -166,7 +135,7 @@ describe('Cost Calculation', () => {
 expect(metrics.costUsd).toBeCloseTo(1.05, 6);
 });

-it('should use tiered rates for input tokens above 200k threshold', () => {
+it('should use base rates for input tokens above 200k when model has no tiered pricing', () => {
 const messages: ParsedMessage[] = [
 {
 type: 'assistant',
@@ -186,13 +155,14 @@ describe('Cost Calculation', () => {

 const metrics = calculateMetrics(messages);

-// Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
+// claude-3-5-sonnet-20241022 has no tiered rates in pricing.json, so base rates apply
+// Input: 250000 * 0.000003 = 0.75
 // Output: 1000 * 0.000015 = 0.015
-// Total: 0.915
-expect(metrics.costUsd).toBeCloseTo(0.915, 6);
+// Total: 0.765
+expect(metrics.costUsd).toBeCloseTo(0.765, 6);
 });

-it('should use tiered rates for output tokens above 200k threshold', () => {
+it('should use base rates for output tokens above 200k when model has no tiered pricing', () => {
 const messages: ParsedMessage[] = [
 {
 type: 'assistant',
@@ -212,13 +182,14 @@ describe('Cost Calculation', () => {

 const metrics = calculateMetrics(messages);

+// No tiered rates, so base rates for all tokens
 // Input: 1000 * 0.000003 = 0.003
-// Output: (200000 * 0.000015) + (50000 * 0.00003) = 3.0 + 1.5 = 4.5
-// Total: 4.503
-expect(metrics.costUsd).toBeCloseTo(4.503, 6);
+// Output: 250000 * 0.000015 = 3.75
+// Total: 3.753
+expect(metrics.costUsd).toBeCloseTo(3.753, 6);
 });

-it('should use tiered rates for cache tokens above 200k threshold', () => {
+it('should use base rates for cache tokens above 200k when model has no tiered pricing', () => {
 const messages: ParsedMessage[] = [
 {
 type: 'assistant',
@@ -240,12 +211,13 @@ describe('Cost Calculation', () => {

 const metrics = calculateMetrics(messages);

+// No tiered rates for this model, so base rates apply
 // Input: 1000 * 0.000003 = 0.003
 // Output: 1000 * 0.000015 = 0.015
-// Cache creation: (200000 * 0.00000375) + (50000 * 0.0000075) = 0.75 + 0.375 = 1.125
-// Cache read: (200000 * 0.0000003) + (50000 * 0.0000006) = 0.06 + 0.03 = 0.09
-// Total: 1.233
-expect(metrics.costUsd).toBeCloseTo(1.233, 6);
+// Cache creation: 250000 * 0.00000375 = 0.9375
+// Cache read: 250000 * 0.0000003 = 0.075
+// Total: 1.0305
+expect(metrics.costUsd).toBeCloseTo(1.0305, 6);
 });

 it('should handle model without tiered pricing', () => {
@ -274,6 +246,34 @@ describe('Cost Calculation', () => {
|
|||
// Total: 22.5
|
||||
expect(metrics.costUsd).toBeCloseTo(22.5, 6);
|
||||
});
|
||||
|
||||
it('should use tiered rates for a model that has them (claude-4-sonnet)', () => {
|
||||
const messages: ParsedMessage[] = [
|
||||
{
|
||||
type: 'assistant',
|
||||
uuid: 'msg-1',
|
||||
timestamp: new Date(),
|
||||
content: [],
|
||||
model: 'claude-4-sonnet-20250514',
|
||||
usage: {
|
||||
input_tokens: 250_000,
|
||||
output_tokens: 1_000,
|
||||
},
|
||||
toolCalls: [],
|
||||
toolResults: [],
|
||||
isSidechain: false,
|
||||
},
|
||||
];
|
||||
|
||||
const metrics = calculateMetrics(messages);
|
||||
|
||||
// claude-4-sonnet has tiered rates:
|
||||
// input base=0.000003, above_200k=0.000006
|
||||
// Input: (200000 * 0.000003) + (50000 * 0.000006) = 0.6 + 0.3 = 0.9
|
||||
// Output: 1000 * 0.000015 = 0.015
|
||||
// Total: 0.915
|
||||
expect(metrics.costUsd).toBeCloseTo(0.915, 6);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multiple Messages', () => {
|
||||
|
|
@@ -405,48 +405,6 @@ describe('Cost Calculation', () => {
 const metrics = calculateMetrics(messages);
 expect(metrics.costUsd).toBe(0);
 });
-
-it('should handle pricing data load failure gracefully', async () => {
-// Suppress expected console.error for this test
-const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
-
-// Reset modules to clear the pricing cache
-vi.resetModules();
-
-// Mock fs to throw error BEFORE importing calculateMetrics
-vi.mocked(fs.readFileSync).mockImplementation(() => {
-throw new Error('File not found');
-});
-
-// Re-import calculateMetrics to get fresh instance with cleared cache
-const { calculateMetrics: freshCalculateMetrics } = await import('@main/utils/jsonl');
-
-const messages: ParsedMessage[] = [
-{
-type: 'assistant',
-uuid: 'msg-1',
-timestamp: new Date(),
-content: [],
-model: 'claude-3-5-sonnet-20241022',
-usage: {
-input_tokens: 1000,
-output_tokens: 500,
-},
-toolCalls: [],
-toolResults: [],
-isSidechain: false,
-},
-];
-
-const metrics = freshCalculateMetrics(messages);
-expect(metrics.costUsd).toBe(0);
-
-// Verify that console.error was called (error was logged)
-expect(consoleErrorSpy).toHaveBeenCalled();
-
-// Restore console.error
-consoleErrorSpy.mockRestore();
-});
 });

 describe('Model Name Lookup', () => {
@@ -535,7 +493,7 @@ describe('Cost Calculation', () => {
 expect(metrics.costUsd).not.toBeCloseTo(incorrectAggregatedCost, 2);
 });

-it('should apply tiered rates when individual messages exceed 200k', () => {
+it('should use base rates when individual messages exceed 200k and model has no tiered rates', () => {
 const messages: ParsedMessage[] = [
 {
 type: 'assistant',
@@ -556,11 +514,9 @@ describe('Cost Calculation', () => {

 const metrics = calculateMetrics(messages);

-// Single message with 300k cache_read tokens
-// First 200k: 200,000 * 0.0000003 = $0.06
-// Remaining 100k: 100,000 * 0.0000006 = $0.06
-// Total: $0.12
-const expectedCost = 200000 * 0.0000003 + 100000 * 0.0000006;
+// No tiered rates for this model, so all 300k at base rate
+// 300,000 * 0.0000003 = $0.09
+const expectedCost = 300000 * 0.0000003;
 expect(metrics.costUsd).toBeCloseTo(expectedCost, 6);
 });
 });
|
|||
85
test/shared/utils/pricing.test.ts
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
getPricing,
|
||||
calculateTieredCost,
|
||||
calculateMessageCost,
|
||||
getDisplayPricing,
|
||||
} from '@shared/utils/pricing';
|
||||
|
||||
describe('Shared Pricing Module', () => {
|
||||
describe('getPricing', () => {
|
||||
it('should find pricing by exact model name', () => {
|
||||
const pricing = getPricing('claude-3-5-sonnet-20241022');
|
||||
expect(pricing).not.toBeNull();
|
||||
expect(pricing!.input_cost_per_token).toBeGreaterThan(0);
|
||||
expect(pricing!.output_cost_per_token).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should find pricing case-insensitively', () => {
|
||||
const pricing = getPricing('Claude-3-5-Sonnet-20241022');
|
||||
expect(pricing).not.toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for unknown models', () => {
|
||||
const pricing = getPricing('totally-fake-model-xyz');
|
||||
expect(pricing).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('calculateTieredCost', () => {
|
||||
it('should use base rate for tokens below 200k', () => {
|
||||
const cost = calculateTieredCost(100_000, 0.000003);
|
||||
expect(cost).toBeCloseTo(0.3, 6);
|
||||
});
|
||||
|
||||
it('should apply tiered rate above 200k', () => {
|
||||
const cost = calculateTieredCost(250_000, 0.000003, 0.000006);
|
||||
expect(cost).toBeCloseTo(0.9, 6);
|
||||
});
|
||||
|
||||
it('should use base rate when no tiered rate provided', () => {
|
||||
const cost = calculateTieredCost(250_000, 0.000015);
|
||||
expect(cost).toBeCloseTo(3.75, 6);
|
||||
});
|
||||
|
||||
it('should return 0 for zero or negative tokens', () => {
|
||||
expect(calculateTieredCost(0, 0.000003)).toBe(0);
|
||||
expect(calculateTieredCost(-100, 0.000003)).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('calculateMessageCost', () => {
|
||||
it('should compute cost for a known model', () => {
|
||||
const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 0, 0);
|
||||
expect(cost).toBeCloseTo(0.0105, 6);
|
||||
});
|
||||
|
||||
it('should return 0 for unknown models', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined);
|
||||
const cost = calculateMessageCost('unknown-model', 1000, 500, 0, 0);
|
||||
expect(cost).toBe(0);
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[pricing] No pricing data for model "unknown-model", cost will be $0'
|
||||
);
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should include cache token costs', () => {
|
||||
const cost = calculateMessageCost('claude-3-5-sonnet-20241022', 1000, 500, 300, 200);
|
||||
expect(cost).toBeGreaterThan(0.0105);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDisplayPricing', () => {
|
||||
it('should return per-million rates for a known model', () => {
|
||||
const dp = getDisplayPricing('claude-3-5-sonnet-20241022');
|
||||
expect(dp).not.toBeNull();
|
||||
expect(dp!.input).toBeCloseTo(3.0, 1);
|
||||
expect(dp!.output).toBeCloseTo(15.0, 1);
|
||||
});
|
||||
|
||||
it('should return null for unknown models', () => {
|
||||
expect(getDisplayPricing('unknown-model')).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
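The lookup behaviour exercised by the `getPricing` and `getDisplayPricing` tests (exact match first, then a case-insensitive fallback, then per-token rates scaled to per-million) can be sketched in isolation as follows. The table entries, model names, and the helper names `lookup` and `perMillion` are hypothetical stand-ins; the real module reads `resources/pricing.json`.

```typescript
interface Rates {
  input_cost_per_token: number;
  output_cost_per_token: number;
}

// Hypothetical two-entry table standing in for the bundled pricing data.
const TABLE = new Map<string, Rates>([
  ['Example-Model-A', { input_cost_per_token: 0.000003, output_cost_per_token: 0.000015 }],
  ['example-model-b', { input_cost_per_token: 0.000015, output_cost_per_token: 0.000075 }],
]);

// Pre-compute lowercase -> original key once, so the fallback is a single map lookup.
const LOWERCASE_KEYS = new Map<string, string>();
for (const key of TABLE.keys()) {
  const lower = key.toLowerCase();
  if (!LOWERCASE_KEYS.has(lower)) LOWERCASE_KEYS.set(lower, key);
}

function lookup(modelName: string): Rates | null {
  // 1. Exact match.
  const exact = TABLE.get(modelName);
  if (exact) return exact;
  // 2. Case-insensitive match through the pre-computed key map.
  const originalKey = LOWERCASE_KEYS.get(modelName.toLowerCase());
  if (originalKey === undefined) return null;
  return TABLE.get(originalKey) ?? null;
}

// Convert per-token rates to the USD-per-1M-token figures shown in the UI.
function perMillion(rates: Rates): { input: number; output: number } {
  return {
    input: rates.input_cost_per_token * 1_000_000,
    output: rates.output_cost_per_token * 1_000_000,
  };
}

const found = lookup('EXAMPLE-MODEL-A');
const missing = lookup('totally-fake-model-xyz');
console.log(found ? perMillion(found) : null, missing);
```

Returning `null` for unknown models, rather than falling back to a default rate as the removed renderer table did, leaves it to the caller to decide how to report a missing price.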