feat(worker): let Claude set the cheapest model per generated task via MCP

AddTask, planning CreateChildTask, and SuggestImprovement now accept an
optional alias-validated model (haiku/sonnet/opus; blank = inherit) so the
model is chosen at creation time instead of a follow-up set_task_config call.
The planning, system, and improvement prompts instruct Claude to pick the
cheapest capable model (haiku < sonnet < opus).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
mika kuns
2026-06-09 22:22:17 +02:00
parent 1448794748
commit c27a179d2b
12 changed files with 181 additions and 18 deletions

View File

@@ -0,0 +1,36 @@
using ClaudeDo.Data.Models;
namespace ClaudeDo.Data.Tests;
public class ModelRegistryTests
{
[Theory]
[InlineData("sonnet", "sonnet")]
[InlineData("OPUS", "opus")]
[InlineData(" haiku ", "haiku")]
public void NormalizeAlias_canonicalizes_known_aliases(string input, string expected)
{
Assert.Equal(expected, ModelRegistry.NormalizeAlias(input));
}
[Theory]
[InlineData(null)]
[InlineData("")]
[InlineData(" ")]
public void NormalizeAlias_blank_means_inherit(string? input)
{
Assert.Null(ModelRegistry.NormalizeAlias(input));
}
[Fact]
public void NormalizeAlias_unknown_throws()
{
Assert.Throws<ArgumentException>(() => ModelRegistry.NormalizeAlias("gpt4"));
}
[Fact]
public void ByCostAscending_is_haiku_sonnet_opus()
{
Assert.Equal(new[] { "haiku", "sonnet", "opus" }, ModelRegistry.ByCostAscending);
}
}

View File

@@ -729,6 +729,42 @@ public sealed class ExternalMcpServiceTests : IDisposable
Assert.Equal(TaskStatus.Done, reloaded!.Status);
}
// ── AddTask model override ────────────────────────────────────────────────
[Fact]
public async Task AddTask_NoModel_LeavesModelNull()
{
var listId = await SeedListAsync();
var sut = NewService();
var dto = await sut.AddTask(listId, "t", cancellationToken: CancellationToken.None);
var loaded = await _tasks.GetByIdAsync(dto.Id);
Assert.Null(loaded!.Model);
}
[Fact]
public async Task AddTask_PersistsNormalizedModel()
{
var listId = await SeedListAsync();
var sut = NewService();
var dto = await sut.AddTask(listId, "t", model: "HAIKU", cancellationToken: CancellationToken.None);
var loaded = await _tasks.GetByIdAsync(dto.Id);
Assert.Equal("haiku", loaded!.Model);
}
[Fact]
public async Task AddTask_RejectsUnknownModel()
{
var listId = await SeedListAsync();
var sut = NewService();
await Assert.ThrowsAsync<ArgumentException>(
() => sut.AddTask(listId, "t", model: "gpt4", cancellationToken: CancellationToken.None));
}
// ── ContinueTask validation ───────────────────────────────────────────────
[Fact]

View File

@@ -111,8 +111,8 @@ public sealed class PlanningEndToEndTests : IDisposable
// Wire the ambient context so _svc reads the correct parent
_httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id };
await _svc.CreateChildTask("sub 1", null, null, CancellationToken.None);
await _svc.CreateChildTask("sub 2", null, null, CancellationToken.None);
await _svc.CreateChildTask("sub 1", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("sub 2", null, null, null, CancellationToken.None);
var count = await _svc.Finalize(true, CancellationToken.None);
Assert.Equal(2, count);
@@ -155,9 +155,9 @@ public sealed class PlanningEndToEndTests : IDisposable
await _manager.StartAsync(parent.Id, CancellationToken.None);
_httpContext.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id };
await _svc.CreateChildTask("c1", null, null, CancellationToken.None);
await _svc.CreateChildTask("c2", null, null, CancellationToken.None);
await _svc.CreateChildTask("c3", null, null, CancellationToken.None);
await _svc.CreateChildTask("c1", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("c2", null, null, null, CancellationToken.None);
await _svc.CreateChildTask("c3", null, null, null, CancellationToken.None);
var kidsBefore = await _tasks.GetChildrenAsync(parent.Id);
var firstChildId = kidsBefore[0].Id;

View File

@@ -108,12 +108,35 @@ public sealed class PlanningMcpServiceTests : IDisposable
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("My child", "desc", null, CancellationToken.None);
var result = await sut.CreateChildTask("My child", "desc", null, model: null, CancellationToken.None);
Assert.Equal("Idle", result.Status);
var child = await _tasks.GetByIdAsync(result.TaskId);
Assert.Equal("My child", child!.Title);
Assert.Equal(TaskStatus.Idle, child.Status);
Assert.Null(child.Model);
}
[Fact]
public async Task CreateChildTask_PersistsNormalizedModel()
{
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("c", null, null, model: "Opus", CancellationToken.None);
var child = await _tasks.GetByIdAsync(result.TaskId);
Assert.Equal("opus", child!.Model);
}
[Fact]
public async Task CreateChildTask_RejectsUnknownModel()
{
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
await Assert.ThrowsAsync<ArgumentException>(
() => sut.CreateChildTask("c", null, null, model: "turbo", CancellationToken.None));
}
[Fact]
@@ -244,7 +267,7 @@ public sealed class PlanningMcpServiceTests : IDisposable
var parent = await SeedPlanningParentAsync();
var sut = BuildSut(parent.Id);
var result = await sut.CreateChildTask("c", null, null, CancellationToken.None);
var result = await sut.CreateChildTask("c", null, null, model: null, CancellationToken.None);
var ids = TaskUpdatedIds();
Assert.Contains(result.TaskId, ids);

View File

@@ -39,12 +39,36 @@ public sealed class SuggestImprovementTests : IDisposable
using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext()));
var dto = await svc.SuggestImprovement("Refactor X", "details", default);
var dto = await svc.SuggestImprovement("Refactor X", "details", model: null, default);
var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId);
Assert.Equal("caller", child!.ParentTaskId);
Assert.Equal("caller", child.CreatedBy);
Assert.Equal(TaskStatus.Idle, child.Status);
Assert.Equal("l1", child.ListId);
Assert.Null(child.Model);
}
[Fact]
public async Task SuggestImprovement_persists_normalized_model()
{
await SeedCallerAsync("caller", parentId: null);
using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext()));
var dto = await svc.SuggestImprovement("Refactor X", "details", model: "HAIKU", default);
var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId);
Assert.Equal("haiku", child!.Model);
}
[Fact]
public async Task SuggestImprovement_rejects_unknown_model()
{
await SeedCallerAsync("caller", parentId: null);
using var ctx = _db.CreateContext();
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"),
new HubBroadcaster(new FakeHubContext()));
await Assert.ThrowsAsync<ArgumentException>(
() => svc.SuggestImprovement("x", "y", model: "gpt4", default));
}
[Fact]
@@ -56,6 +80,6 @@ public sealed class SuggestImprovementTests : IDisposable
var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("child"),
new HubBroadcaster(new FakeHubContext()));
await Assert.ThrowsAsync<InvalidOperationException>(
() => svc.SuggestImprovement("nested", "x", default));
() => svc.SuggestImprovement("nested", "x", model: null, default));
}
}