feat(worker): let Claude set the cheapest model per generated task via MCP

AddTask, planning CreateChildTask, and SuggestImprovement now accept an
optional alias-validated model (haiku/sonnet/opus; blank = inherit) so the
model is chosen at creation time instead of a follow-up set_task_config call.
The planning, system, and improvement prompts instruct Claude to pick the
cheapest capable model (haiku < sonnet < opus).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
mika kuns
2026-06-09 22:22:17 +02:00
parent 1448794748
commit c27a179d2b
12 changed files with 181 additions and 18 deletions

View File

@@ -4,9 +4,26 @@ public static class ModelRegistry
{ {
public static readonly IReadOnlyList<string> Aliases = new[] { "sonnet", "opus", "haiku" }; public static readonly IReadOnlyList<string> Aliases = new[] { "sonnet", "opus", "haiku" };
/// <summary>Model aliases ordered cheapest → most capable. Single source for prompt cost guidance.</summary>
public static readonly IReadOnlyList<string> ByCostAscending = new[] { "haiku", "sonnet", "opus" };
public const string DefaultAlias = "sonnet"; public const string DefaultAlias = "sonnet";
public const string PlanningAlias = "opus"; public const string PlanningAlias = "opus";
public const string ListDefaultSentinel = "(default)"; public const string ListDefaultSentinel = "(default)";
public const string TaskInheritSentinel = "(inherit)"; public const string TaskInheritSentinel = "(inherit)";
/// <summary>
/// Validate a model alias from external input. Null/blank → null (inherit).
/// Returns the canonical lowercase alias; throws on an unknown value.
/// </summary>
public static string? NormalizeAlias(string? model)
{
var m = model?.Trim();
if (string.IsNullOrEmpty(m)) return null;
foreach (var alias in Aliases)
if (string.Equals(alias, m, StringComparison.OrdinalIgnoreCase))
return alias;
throw new ArgumentException($"Unknown model '{model}'. Allowed: {string.Join(", ", Aliases)}.");
}
} }

View File

@@ -82,7 +82,10 @@ public static class PromptFiles
## Out-of-scope improvements ## Out-of-scope improvements
If you notice worthwhile work that is genuinely outside this task's scope If you notice worthwhile work that is genuinely outside this task's scope
(a refactor, a follow-up, tech debt), do NOT do it here. File it with (a refactor, a follow-up, tech debt), do NOT do it here. File it with
SuggestImprovement(title, description) and stay focused on the task at hand. SuggestImprovement(title, description, model) and stay focused on the task at hand.
Set `model` to the cheapest model that can do the follow-up well 'haiku' for
trivial/mechanical work, 'sonnet' for normal coding, 'opus' only for genuinely
complex work (cheapest to most capable: haiku < sonnet < opus).
## Working in the repo ## Working in the repo
- Read a file before editing it. Match the conventions already in this codebase - Read a file before editing it. Match the conventions already in this codebase
@@ -122,8 +125,8 @@ public static class PromptFiles
# Out-of-scope follow-up # Out-of-scope follow-up
You are an improvement follow-up that another task filed via SuggestImprovement. You are an improvement follow-up that another task filed via SuggestImprovement.
It was deliberately scoped narrow. Do EXACTLY what this task's title and It was deliberately scoped narrow, and is intentionally a small, cheap unit of
description ask nothing more. work. Do EXACTLY what this task's title and description ask nothing more.
- Make the smallest change that satisfies the task. No opportunistic refactors, - Make the smallest change that satisfies the task. No opportunistic refactors,
renames, reformatting, or "while I'm here" cleanup beyond what is asked. renames, reformatting, or "while I'm here" cleanup beyond what is asked.
@@ -150,6 +153,14 @@ public static class PromptFiles
Once the design is approved, create the child tasks with CreateChildTask, then Once the design is approved, create the child tasks with CreateChildTask, then
call Finalize. Keep each subtask concrete and self-contained with a clear call Finalize. Keep each subtask concrete and self-contained with a clear
done-state, ordered so dependencies come first. done-state, ordered so dependencies come first.
For each subtask, pass CreateChildTask's `model` argument set to the CHEAPEST
model that can do that subtask well. Models, cheapest to most capable:
haiku < sonnet < opus.
- haiku trivial/mechanical work: doc tweaks, simple renames, small localized edits.
- sonnet normal coding work; the sensible default when unsure.
- opus only for genuinely complex, cross-cutting, or hard-to-debug work.
Do not default everything to opus most subtasks are haiku or sonnet.
"""; """;
private const string PlanningInitialDefault = """ private const string PlanningInitialDefault = """

View File

@@ -197,6 +197,7 @@ public sealed class TaskRepository
string? description, string? description,
string? commitType, string? commitType,
string? createdBy = null, string? createdBy = null,
string? model = null,
CancellationToken ct = default) CancellationToken ct = default)
{ {
// AsNoTracking: SetPlanningStartedAsync mutates via ExecuteUpdate which // AsNoTracking: SetPlanningStartedAsync mutates via ExecuteUpdate which
@@ -223,6 +224,7 @@ public sealed class TaskRepository
ParentTaskId = parentId, ParentTaskId = parentId,
SortOrder = (maxSort ?? -1) + 1, SortOrder = (maxSort ?? -1) + 1,
CreatedBy = createdBy, CreatedBy = createdBy,
Model = ModelRegistry.NormalizeAlias(model),
}; };
_context.Tasks.Add(child); _context.Tasks.Add(child);
await _context.SaveChangesAsync(ct); await _context.SaveChangesAsync(ct);

View File

@@ -166,7 +166,7 @@ Loaded from `~/.todo-app/worker.config.json`:
- `signalr_port` (default 47821) - `signalr_port` (default 47821)
- `claude_bin` (path to claude CLI) - `claude_bin` (path to claude CLI)
Per-list config (`list_config` in DB) provides defaults for `model`, `system_prompt`, `agent_path`; tasks can override each individually. Per-list config (`list_config` in DB) provides defaults for `model`, `system_prompt`, `agent_path`; tasks can override each individually. Task-generating MCP tools (`AddTask`, planning `CreateChildTask`, `SuggestImprovement`) accept an optional `model` (alias-validated via `ModelRegistry.NormalizeAlias` — `haiku`/`sonnet`/`opus`, blank = inherit) so Claude assigns the cheapest capable model at creation time; the planning/system/improvement prompts instruct it to do so (`ModelRegistry.ByCostAscending` = the cost order).
## Notes ## Notes

View File

@@ -142,13 +142,18 @@ public sealed class ExternalMcpService
return ToDto(task); return ToDto(task);
} }
[McpServerTool, Description("Create a new task in the given list. Set queueImmediately=true to enqueue it for agent execution.")] [McpServerTool, Description(
"Create a new task in the given list. Set queueImmediately=true to enqueue it for agent execution. " +
"Set model to the cheapest model that can do the task well — 'haiku' for trivial/mechanical work, " +
"'sonnet' for normal coding (the default), 'opus' only for complex or cross-cutting work. " +
"Leave model null to inherit the list/global default.")]
public async Task<TaskDto> AddTask( public async Task<TaskDto> AddTask(
string listId, string listId,
string title, string title,
string? description = null, string? description = null,
string? createdBy = null, string? createdBy = null,
bool queueImmediately = false, bool queueImmediately = false,
string? model = null,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
if (string.IsNullOrWhiteSpace(listId)) if (string.IsNullOrWhiteSpace(listId))
@@ -169,6 +174,7 @@ public sealed class ExternalMcpService
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
CommitType = list.DefaultCommitType, CommitType = list.DefaultCommitType,
CreatedBy = createdBy.NullIfBlank() ?? "mcp", CreatedBy = createdBy.NullIfBlank() ?? "mcp",
Model = ModelRegistry.NormalizeAlias(model),
}; };
await _tasks.AddAsync(entity, cancellationToken); await _tasks.AddAsync(entity, cancellationToken);

View File

@@ -37,15 +37,20 @@ public sealed class PlanningMcpService
private Task BroadcastTaskUpdatedAsync(string taskId, CancellationToken ct) private Task BroadcastTaskUpdatedAsync(string taskId, CancellationToken ct)
=> _broadcaster.TaskUpdated(taskId); => _broadcaster.TaskUpdated(taskId);
[McpServerTool, Description("Create a new draft child task under the current planning session's parent task.")] [McpServerTool, Description(
"Create a new draft child task under the current planning session's parent task. " +
"Set model to the cheapest model that can do this subtask well — 'haiku' for trivial/mechanical " +
"work, 'sonnet' for normal coding (the default), 'opus' only for complex or cross-cutting work. " +
"Leave model null to inherit the list/global default.")]
public async Task<CreatedChildDto> CreateChildTask( public async Task<CreatedChildDto> CreateChildTask(
string title, string title,
string? description, string? description,
string? commitType, string? commitType,
string? model,
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
var ctx = _contextAccessor.Current; var ctx = _contextAccessor.Current;
var child = await _tasks.CreateChildAsync(ctx.ParentTaskId, title, description, commitType, createdBy: null, cancellationToken); var child = await _tasks.CreateChildAsync(ctx.ParentTaskId, title, description, commitType, createdBy: null, model: model, ct: cancellationToken);
await BroadcastTaskUpdatedAsync(child.Id, cancellationToken); await BroadcastTaskUpdatedAsync(child.Id, cancellationToken);
await BroadcastTaskUpdatedAsync(ctx.ParentTaskId, cancellationToken); await BroadcastTaskUpdatedAsync(ctx.ParentTaskId, cancellationToken);
return new CreatedChildDto(child.Id, child.Status.ToString()); return new CreatedChildDto(child.Id, child.Status.ToString());

View File

@@ -25,10 +25,13 @@ public sealed class TaskRunMcpService
"File an out-of-scope improvement as a child task of the current task. The child runs " + "File an out-of-scope improvement as a child task of the current task. The child runs " +
"automatically after this task finishes and is surfaced for review alongside it. Use ONLY " + "automatically after this task finishes and is surfaced for review alongside it. Use ONLY " +
"for work that is genuinely outside this task's scope (a refactor, follow-up, or tech debt) " + "for work that is genuinely outside this task's scope (a refactor, follow-up, or tech debt) " +
"— never for work that belongs to the current task.")] "— never for work that belongs to the current task. Set model to the cheapest model that can " +
"do the follow-up well — 'haiku' for trivial/mechanical work, 'sonnet' for normal coding, " +
"'opus' only for complex work. Leave model null to inherit the list/global default.")]
public async Task<SuggestedImprovementDto> SuggestImprovement( public async Task<SuggestedImprovementDto> SuggestImprovement(
string title, string title,
string description, string description,
string? model,
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
var callerId = _ctx.Current.CallerTaskId; var callerId = _ctx.Current.CallerTaskId;
@@ -39,7 +42,7 @@ public sealed class TaskRunMcpService
"A child task cannot suggest further improvements (improvements are one layer deep)."); "A child task cannot suggest further improvements (improvements are one layer deep).");
var child = await _tasks.CreateChildAsync( var child = await _tasks.CreateChildAsync(
callerId, title, description, commitType: null, createdBy: callerId, cancellationToken); callerId, title, description, commitType: null, createdBy: callerId, model: model, ct: cancellationToken);
await _broadcaster.TaskUpdated(child.Id); await _broadcaster.TaskUpdated(child.Id);
await _broadcaster.TaskUpdated(callerId); await _broadcaster.TaskUpdated(callerId);
return new SuggestedImprovementDto(child.Id); return new SuggestedImprovementDto(child.Id);

View File

@@ -0,0 +1,36 @@
using ClaudeDo.Data.Models;
namespace ClaudeDo.Data.Tests;
public class ModelRegistryTests
{
[Theory]
[InlineData("sonnet", "sonnet")]
[InlineData("OPUS", "opus")]
[InlineData(" haiku ", "haiku")]
public void NormalizeAlias_canonicalizes_known_aliases(string input, string expected)
{
Assert.Equal(expected, ModelRegistry.NormalizeAlias(input));
}
[Theory]
[InlineData(null)]
[InlineData("")]
[InlineData(" ")]
public void NormalizeAlias_blank_means_inherit(string? input)
{
Assert.Null(ModelRegistry.NormalizeAlias(input));
}
[Fact]
public void NormalizeAlias_unknown_throws()
{
Assert.Throws<ArgumentException>(() => ModelRegistry.NormalizeAlias("gpt4"));
}
[Fact]
public void ByCostAscending_is_haiku_sonnet_opus()
{
Assert.Equal(new[] { "haiku", "sonnet", "opus" }, ModelRegistry.ByCostAscending);
}
}

View File

@@ -729,6 +729,42 @@ public sealed class ExternalMcpServiceTests : IDisposable
Assert.Equal(TaskStatus.Done, reloaded!.Status); Assert.Equal(TaskStatus.Done, reloaded!.Status);
} }
// ── AddTask model override ────────────────────────────────────────────────
[Fact]
public async Task AddTask_NoModel_LeavesModelNull()
{
var listId = await SeedListAsync();
var sut = NewService();
var dto = await sut.AddTask(listId, "t", cancellationToken: CancellationToken.None);
var loaded = await _tasks.GetByIdAsync(dto.Id);
Assert.Null(loaded!.Model);
}
[Fact]
public async Task AddTask_PersistsNormalizedModel()
{
var listId = await SeedListAsync();
var sut = NewService();
var dto = await sut.AddTask(listId, "t", model: "HAIKU", cancellationToken: CancellationToken.None);
var loaded = await _tasks.GetByIdAsync(dto.Id);
Assert.Equal("haiku", loaded!.Model);
}
[Fact]
public async Task AddTask_RejectsUnknownModel()
{
var listId = await SeedListAsync();
var sut = NewService();
await Assert.ThrowsAsync<ArgumentException>(
() => sut.AddTask(listId, "t", model: "gpt4", cancellationToken: CancellationToken.None));
}
// ── ContinueTask validation ─────────────────────────────────────────────── // ── ContinueTask validation ───────────────────────────────────────────────
[Fact] [Fact]

View File

@@ -111,8 +111,8 @@ public sealed class PlanningEndToEndTests : IDisposable
// Wire the ambient context so _svc reads the correct parent // Wire the ambient context so _svc reads the correct parent
_httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id }; _httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id };
await _svc.CreateChildTask("sub 1", null, null, CancellationToken.None); await _svc.CreateChildTask("sub 1", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("sub 2", null, null, CancellationToken.None); await _svc.CreateChildTask("sub 2", null, null, null, CancellationToken.None);
var count = await _svc.Finalize(true, CancellationToken.None); var count = await _svc.Finalize(true, CancellationToken.None);
Assert.Equal(2, count); Assert.Equal(2, count);
@@ -155,9 +155,9 @@ public sealed class PlanningEndToEndTests : IDisposable
await _manager.StartAsync(parent.Id, CancellationToken.None); await _manager.StartAsync(parent.Id, CancellationToken.None);
_httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id }; _httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id };
await _svc.CreateChildTask("c1", null, null, CancellationToken.None); await _svc.CreateChildTask("c1", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("c2", null, null, CancellationToken.None); await _svc.CreateChildTask("c2", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("c3", null, null, CancellationToken.None); await _svc.CreateChildTask("c3", null, null, null, CancellationToken.None);
var kidsBefore = await _tasks.GetChildrenAsync(parent.Id); var kidsBefore = await _tasks.GetChildrenAsync(parent.Id);
var firstChildId = kidsBefore[0].Id; var firstChildId = kidsBefore[0].Id;

View File

@@ -108,12 +108,35 @@ public sealed class PlanningMcpServiceTests : IDisposable
var parent = await SeedPlanningParentAsync(); var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id); var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("My child", "desc", null, CancellationToken.None); var result = await sut.CreateChildTask("My child", "desc", null, model: null, CancellationToken.None);
Assert.Equal("Idle", result.Status); Assert.Equal("Idle", result.Status);
var child = await _tasks.GetByIdAsync(result.TaskId); var child = await _tasks.GetByIdAsync(result.TaskId);
Assert.Equal("My child", child!.Title); Assert.Equal("My child", child!.Title);
Assert.Equal(TaskStatus.Idle, child.Status); Assert.Equal(TaskStatus.Idle, child.Status);
Assert.Null(child.Model);
}
[Fact]
public async Task CreateChildTask_PersistsNormalizedModel()
{
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("c", null, null, model: "Opus", CancellationToken.None);
var child = await _tasks.GetByIdAsync(result.TaskId);
Assert.Equal("opus", child!.Model);
}
[Fact]
public async Task CreateChildTask_RejectsUnknownModel()
{
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
await Assert.ThrowsAsync<ArgumentException>(
() => sut.CreateChildTask("c", null, null, model: "turbo", CancellationToken.None));
} }
[Fact] [Fact]
@@ -244,7 +267,7 @@ public sealed class PlanningMcpServiceTests : IDisposable
var parent = await SeedPlanningParentAsync(); var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id); var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("c", null, null, CancellationToken.None); var result = await sut.CreateChildTask("c", null, null, model: null, CancellationToken.None);
var ids = TaskUpdatedIds(); var ids = TaskUpdatedIds();
Assert.Contains(result.TaskId, ids); Assert.Contains(result.TaskId, ids);

View File

@@ -39,12 +39,36 @@ public sealed class SuggestImprovementTests : IDisposable
using var ctx = _db.CreateContext(); using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"), var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext())); new HubBroadcaster(new FakeHubContext()));
var dto = await svc.SuggestImprovement("Refactor X", "details", default); var dto = await svc.SuggestImprovement("Refactor X", "details", model: null, default);
var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId); var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId);
Assert.Equal("caller", child!.ParentTaskId); Assert.Equal("caller", child!.ParentTaskId);
Assert.Equal("caller", child.CreatedBy); Assert.Equal("caller", child.CreatedBy);
Assert.Equal(TaskStatus.Idle, child.Status); Assert.Equal(TaskStatus.Idle, child.Status);
Assert.Equal("l1", child.ListId); Assert.Equal("l1", child.ListId);
Assert.Null(child.Model);
}
[Fact]
public async Task SuggestImprovement_persists_normalized_model()
{
await SeedCallerAsync("caller", parentId: null);
using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext()));
var dto = await svc.SuggestImprovement("Refactor X", "details", model: "HAIKU", default);
var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId);
Assert.Equal("haiku", child!.Model);
}
[Fact]
public async Task SuggestImprovement_rejects_unknown_model()
{
await SeedCallerAsync("caller", parentId: null);
using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext()));
await Assert.ThrowsAsync<ArgumentException>(
() => svc.SuggestImprovement("x", "y", model: "gpt4", default));
} }
[Fact] [Fact]
@@ -56,6 +80,6 @@ public sealed class SuggestImprovementTests : IDisposable
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("child"), var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("child"),
new HubBroadcaster(new FakeHubContext())); new HubBroadcaster(new FakeHubContext()));
await Assert.ThrowsAsync<InvalidOperationException>( await Assert.ThrowsAsync<InvalidOperationException>(
() => svc.SuggestImprovement("nested", "x", default)); () => svc.SuggestImprovement("nested", "x", model: null, default));
} }
} }