Files
ClaudeDo/tests/ClaudeDo.Worker.Tests/SuggestImprovementTests.cs
Mika Kuns c7f8280106 feat(worker): AskUser MCP tool so a running task can ask the user mid-run
A running task can call mcp__claudedo_run__AskUser(question) to block (for up to 3
minutes) while waiting for a human answer. PendingQuestionRegistry holds the pending question and its
TaskCompletionSource; the tool broadcasts TaskQuestionAsked, awaits the answer
(WorkerHub.AnswerTaskQuestion resolves it), and returns it as the tool result —
or a 'proceed on your judgment' fallback on timeout. The run stays Running
throughout (no status/schema change). ClaudeProcess increases MCP_TOOL_TIMEOUT so
the 60s HTTP-MCP cap doesn't kill the wait; the run MCP is now wired for every
task, not just standalone ones. The system prompt is updated to reconcile its 'unattended' wording with the new ability to ask the user.
2026-06-26 16:11:51 +02:00

85 lines
3.7 KiB
C#

using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Hub;
using ClaudeDo.Worker.Runner;
using ClaudeDo.Worker.Tests.Infrastructure;
using Microsoft.AspNetCore.Http;
using TaskStatus = ClaudeDo.Data.Models.TaskStatus;
using Xunit;
namespace ClaudeDo.Worker.Tests;
/// <summary>
/// Tests for the <c>SuggestImprovement</c> tool on <see cref="TaskRunMcpService"/>, which lets a
/// running task create a follow-up child task.
/// </summary>
public sealed class SuggestImprovementTests : IDisposable
{
    // Every seeded caller is placed in this list; the suggested child is expected to land in it too.
    private const string SeededListId = "l1";

    private readonly DbFixture _db = new();

    public void Dispose() => _db.Dispose();

    /// <summary>
    /// Returns an accessor whose HTTP context carries a <see cref="TaskRunMcpContext"/> with
    /// <c>CallerTaskId</c> set to <paramref name="callerTaskId"/>, stored under the
    /// <c>"TaskRunContext"</c> item key.
    /// </summary>
    private static TaskRunMcpContextAccessor AccessorFor(string callerTaskId)
    {
        // Build the context as a local so no null-forgiving operator is needed.
        var httpContext = new DefaultHttpContext();
        httpContext.Items["TaskRunContext"] = new TaskRunMcpContext { CallerTaskId = callerTaskId };
        return new TaskRunMcpContextAccessor(new HttpContextAccessor { HttpContext = httpContext });
    }

    /// <summary>
    /// Inserts a <c>Running</c> task with the given id and optional parent into
    /// <see cref="SeededListId"/>, creating that list first if the database has no lists yet.
    /// </summary>
    /// <param name="id">Id of the task to insert.</param>
    /// <param name="parentId">Parent task id, or <c>null</c> for a top-level task.</param>
    private async Task SeedCallerAsync(string id, string? parentId)
    {
        using var ctx = _db.CreateContext();

        // Only the first seed call inserts the list; later calls reuse it.
        if (!ctx.Lists.Any())
        {
            ctx.Lists.Add(new ListEntity { Id = SeededListId, Name = "L", CreatedAt = DateTime.UtcNow });
        }

        ctx.Tasks.Add(new TaskEntity
        {
            Id = id,
            ListId = SeededListId,
            Title = "Caller",
            Status = TaskStatus.Running,
            ParentTaskId = parentId,
            CommitType = "feat",
            CreatedAt = DateTime.UtcNow,
        });
        await ctx.SaveChangesAsync();
    }

    /// <summary>
    /// Creates the service under test acting for <paramref name="callerTaskId"/>, backed by
    /// <paramref name="repo"/>, with hub broadcasts sent to a capturing hub context.
    /// </summary>
    private static TaskRunMcpService CreateService(TaskRepository repo, string callerTaskId) =>
        new(repo, AccessorFor(callerTaskId),
            new HubBroadcaster(new CapturingHubContext()), new PendingQuestionRegistry());

    [Fact]
    public async Task SuggestImprovement_stamps_parent_createdBy_status_and_list()
    {
        await SeedCallerAsync("caller", parentId: null);
        using var ctx = _db.CreateContext();
        var repo = new TaskRepository(ctx);
        var svc = CreateService(repo, "caller");

        var dto = await svc.SuggestImprovement("Refactor X", "details", model: null, default);

        var child = await repo.GetByIdAsync(dto.ChildTaskId);
        // Fail with an assertion message, not a NullReferenceException, if the child is missing.
        Assert.NotNull(child);
        Assert.Equal("caller", child.ParentTaskId);
        Assert.Equal("caller", child.CreatedBy);
        Assert.Equal(TaskStatus.Idle, child.Status);
        Assert.Equal(SeededListId, child.ListId);
        Assert.Null(child.Model);
    }

    [Fact]
    public async Task SuggestImprovement_persists_normalized_model()
    {
        await SeedCallerAsync("caller", parentId: null);
        using var ctx = _db.CreateContext();
        var repo = new TaskRepository(ctx);
        var svc = CreateService(repo, "caller");

        var dto = await svc.SuggestImprovement("Refactor X", "details", model: "HAIKU", default);

        var child = await repo.GetByIdAsync(dto.ChildTaskId);
        Assert.NotNull(child);
        Assert.Equal("haiku", child.Model);
    }

    [Fact]
    public async Task SuggestImprovement_rejects_unknown_model()
    {
        await SeedCallerAsync("caller", parentId: null);
        using var ctx = _db.CreateContext();
        var svc = CreateService(new TaskRepository(ctx), "caller");

        await Assert.ThrowsAsync<ArgumentException>(
            () => svc.SuggestImprovement("x", "y", model: "gpt4", default));
    }

    [Fact]
    public async Task SuggestImprovement_rejects_when_caller_is_a_child()
    {
        // The caller is itself a child task; it must not be able to suggest a nested child.
        await SeedCallerAsync("parent", parentId: null);
        await SeedCallerAsync("child", parentId: "parent");
        using var ctx = _db.CreateContext();
        var svc = CreateService(new TaskRepository(ctx), "child");

        await Assert.ThrowsAsync<InvalidOperationException>(
            () => svc.SuggestImprovement("nested", "x", model: null, default));
    }
}