From c7f8280106d8d4488dbba045a067eba0ad66d482 Mon Sep 17 00:00:00 2001 From: Mika Kuns Date: Thu, 25 Jun 2026 22:53:34 +0200 Subject: [PATCH] feat(worker): AskUser MCP tool so a running task can ask the user mid-run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A running task can call mcp__claudedo_run__AskUser(question) to block (up to 3 min) on a human answer. PendingQuestionRegistry holds the pending question + TaskCompletionSource; the tool broadcasts TaskQuestionAsked, awaits the answer (WorkerHub.AnswerTaskQuestion resolves it), and returns it as the tool result — or a 'proceed on your judgment' fallback on timeout. The run stays Running throughout (no status/schema change). ClaudeProcess raises MCP_TOOL_TIMEOUT so the 60s HTTP-MCP cap doesn't kill the wait; the run MCP is now wired for every task, not just standalone ones. System prompt updated to reconcile 'unattended'. --- src/ClaudeDo.Data/PromptFiles.cs | 9 +- src/ClaudeDo.Worker/Hub/HubBroadcaster.cs | 6 + src/ClaudeDo.Worker/Hub/WorkerHub.cs | 16 +++ src/ClaudeDo.Worker/Program.cs | 1 + src/ClaudeDo.Worker/Runner/ClaudeProcess.cs | 5 + .../Runner/PendingQuestionRegistry.cs | 51 +++++++ .../Runner/TaskRunMcpService.cs | 71 +++++++++- src/ClaudeDo.Worker/Runner/TaskRunner.cs | 30 +++-- .../Hub/ClearMyDayHubTests.cs | 3 +- .../Hub/OnlineInboxHubTests.cs | 2 +- .../Hub/PlanningHubTests.cs | 3 +- .../Hub/WorktreeStateHubTests.cs | 3 +- .../Runner/PendingQuestionTests.cs | 126 ++++++++++++++++++ .../SuggestImprovementTests.cs | 8 +- 14 files changed, 308 insertions(+), 26 deletions(-) create mode 100644 src/ClaudeDo.Worker/Runner/PendingQuestionRegistry.cs create mode 100644 tests/ClaudeDo.Worker.Tests/Runner/PendingQuestionTests.cs diff --git a/src/ClaudeDo.Data/PromptFiles.cs b/src/ClaudeDo.Data/PromptFiles.cs index 838e093..35b9a5c 100644 --- a/src/ClaudeDo.Data/PromptFiles.cs +++ b/src/ClaudeDo.Data/PromptFiles.cs @@ -105,9 +105,12 @@ public static class PromptFiles - Don't introduce injection/XSS/secret-leak issues. Never commit credentials. ## You are running unattended - You run autonomously with no human watching. There is no one to answer mid-task - questions, so never stop to ask — make the most reasonable decision, note the - assumption, and continue. + You run autonomously, usually with no one watching. Default to making the most + reasonable decision yourself, noting the assumption, and continuing — do not stop + for routine choices. The one exception: at a genuine fork where a wrong guess + would be costly or hard to undo (an irreversible action, contradictory + requirements), you may call AskUser(question) to ask the user and wait briefly for + an answer. If no one responds in time, proceed on your best judgment. ## When you are blocked If something genuinely prevents you from completing part of the task (missing diff --git a/src/ClaudeDo.Worker/Hub/HubBroadcaster.cs b/src/ClaudeDo.Worker/Hub/HubBroadcaster.cs index c0690be..0bbb8f8 100644 --- a/src/ClaudeDo.Worker/Hub/HubBroadcaster.cs +++ b/src/ClaudeDo.Worker/Hub/HubBroadcaster.cs @@ -26,6 +26,12 @@ public sealed class HubBroadcaster : IPrimeBroadcaster, IRefineBroadcaster public Task TaskUpdated(string taskId) => _hub.Clients.All.SendAsync("TaskUpdated", taskId); + public Task TaskQuestionAsked(string taskId, string questionId, string question) => + _hub.Clients.All.SendAsync("TaskQuestionAsked", taskId, questionId, question); + + public Task TaskQuestionResolved(string taskId, string questionId) => + _hub.Clients.All.SendAsync("TaskQuestionResolved", taskId, questionId); + public Task ListUpdated(string listId) => _hub.Clients.All.SendAsync("ListUpdated", listId); diff --git a/src/ClaudeDo.Worker/Hub/WorkerHub.cs b/src/ClaudeDo.Worker/Hub/WorkerHub.cs index 6f897c7..2864fb8 100644 --- a/src/ClaudeDo.Worker/Hub/WorkerHub.cs +++ b/src/ClaudeDo.Worker/Hub/WorkerHub.cs @@ -56,6 +56,7 @@ public record WorktreeOverviewDto( bool PathExistsOnDisk); public record ForceRemoveResultDto(bool Removed, string? Reason); +public record PendingQuestionDto(string TaskId, string QuestionId, string Question); public record MergeResultDto(string Status, IReadOnlyList ConflictFiles, string? ErrorMessage); public record MergePreviewDto(string Status, IReadOnlyList ConflictFiles, int ChangedFileCount); public record MergeTargetsDto(string DefaultBranch, IReadOnlyList LocalBranches); @@ -114,6 +115,7 @@ public sealed class WorkerHub : Microsoft.AspNetCore.SignalR.Hub private readonly WorkerConfig _cfg; private readonly OnlineInboxConfig _onlineInboxConfig; private readonly OnlineTokenStore _onlineTokenStore; + private readonly Runner.PendingQuestionRegistry _pendingQuestions; private readonly LogRingBuffer? _logBuffer; public WorkerHub( @@ -139,6 +141,7 @@ public sealed class WorkerHub : Microsoft.AspNetCore.SignalR.Hub WorkerConfig cfg, OnlineInboxConfig onlineInboxConfig, OnlineTokenStore onlineTokenStore, + Runner.PendingQuestionRegistry pendingQuestions, LogRingBuffer? logBuffer = null) { _queue = queue; @@ -163,9 +166,22 @@ public sealed class WorkerHub : Microsoft.AspNetCore.SignalR.Hub _cfg = cfg; _onlineInboxConfig = onlineInboxConfig; _onlineTokenStore = onlineTokenStore; + _pendingQuestions = pendingQuestions; _logBuffer = logBuffer; } + /// Deliver the user's answer to a question a running task raised via AskUser. + /// Returns false if no matching question is still pending (already answered or timed out). + public bool AnswerTaskQuestion(string taskId, string questionId, string answer) => + _pendingQuestions.TryAnswer(taskId, questionId, answer ?? string.Empty); + + /// The question a running task is currently blocked on, if any (for UI re-attach). + public PendingQuestionDto? GetPendingQuestion(string taskId) + { + var q = _pendingQuestions.Get(taskId); + return q is null ? null : new PendingQuestionDto(q.TaskId, q.QuestionId, q.Question); + } + /// Recent worker log records (last 30 min, all levels) for the Log Visualizer overlay. public IReadOnlyList GetRecentLogs() => _logBuffer?.Snapshot() ?? Array.Empty(); diff --git a/src/ClaudeDo.Worker/Program.cs b/src/ClaudeDo.Worker/Program.cs index aab7c5c..3c4270e 100644 --- a/src/ClaudeDo.Worker/Program.cs +++ b/src/ClaudeDo.Worker/Program.cs @@ -72,6 +72,7 @@ builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); diff --git a/src/ClaudeDo.Worker/Runner/ClaudeProcess.cs b/src/ClaudeDo.Worker/Runner/ClaudeProcess.cs index 6d383f5..9f03d86 100644 --- a/src/ClaudeDo.Worker/Runner/ClaudeProcess.cs +++ b/src/ClaudeDo.Worker/Runner/ClaudeProcess.cs @@ -39,6 +39,11 @@ public sealed class ClaudeProcess : IClaudeProcess foreach (var arg in arguments) psi.ArgumentList.Add(arg); + // Claude Code caps HTTP MCP tool calls at 60 s unless MCP_TOOL_TIMEOUT is raised. + // The in-task AskUser tool blocks up to 3 min waiting for the user, so lift the cap + // (with margin) or that wait would be killed early. Harmless for every other tool. + psi.Environment["MCP_TOOL_TIMEOUT"] = "200000"; + using var process = new Process { StartInfo = psi }; process.Start(); diff --git a/src/ClaudeDo.Worker/Runner/PendingQuestionRegistry.cs b/src/ClaudeDo.Worker/Runner/PendingQuestionRegistry.cs new file mode 100644 index 0000000..c5328bb --- /dev/null +++ b/src/ClaudeDo.Worker/Runner/PendingQuestionRegistry.cs @@ -0,0 +1,51 @@ +using System.Collections.Concurrent; + +namespace ClaudeDo.Worker.Runner; + +public sealed record PendingQuestion(string TaskId, string QuestionId, string Question); + +// In-memory store of questions a running task has raised via the AskUser MCP tool and is +// blocking on. One pending question per task (the run's process is blocked mid-tool-call, +// so it cannot ask twice at once). Kept out of the DB on purpose: a question that outlives +// a Worker restart is already dead (StaleTaskRecovery flips the run to Failed). Singleton. +public sealed class PendingQuestionRegistry +{ + private readonly ConcurrentDictionary _byTask = new(); + + private sealed record Entry(string QuestionId, string Question, TaskCompletionSource Answer); + + // Registers a question for the task and returns its id plus the awaitable answer. + // A second register for the same task replaces any stale entry. + public (string QuestionId, Task Answer) Register(string taskId, string question) + { + var questionId = Guid.NewGuid().ToString("N"); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _byTask[taskId] = new Entry(questionId, question, tcs); + return (questionId, tcs.Task); + } + + // Delivers the answer to a waiting question. Returns false if no matching question is + // pending (already answered, timed out, or stale id). + public bool TryAnswer(string taskId, string questionId, string answer) + { + if (_byTask.TryGetValue(taskId, out var entry) + && entry.QuestionId == questionId + && _byTask.TryRemove(taskId, out _)) + { + return entry.Answer.TrySetResult(answer ?? string.Empty); + } + return false; + } + + public PendingQuestion? Get(string taskId) => + _byTask.TryGetValue(taskId, out var entry) + ? new PendingQuestion(taskId, entry.QuestionId, entry.Question) + : null; + + // Drops a pending question without delivering an answer (timeout/cancel cleanup). + public void Remove(string taskId, string questionId) + { + if (_byTask.TryGetValue(taskId, out var entry) && entry.QuestionId == questionId) + _byTask.TryRemove(taskId, out _); + } +} diff --git a/src/ClaudeDo.Worker/Runner/TaskRunMcpService.cs b/src/ClaudeDo.Worker/Runner/TaskRunMcpService.cs index 4293b5f..ccb2e72 100644 --- a/src/ClaudeDo.Worker/Runner/TaskRunMcpService.cs +++ b/src/ClaudeDo.Worker/Runner/TaskRunMcpService.cs @@ -13,12 +13,28 @@ public sealed class TaskRunMcpService private readonly TaskRepository _tasks; private readonly TaskRunMcpContextAccessor _ctx; private readonly HubBroadcaster _broadcaster; + private readonly PendingQuestionRegistry _pending; - public TaskRunMcpService(TaskRepository tasks, TaskRunMcpContextAccessor ctx, HubBroadcaster broadcaster) + // How long a running task blocks waiting for the user to answer an AskUser question + // before it gives up and proceeds autonomously. NOTE: the spawned claude process must + // run with MCP_TOOL_TIMEOUT raised above this (ClaudeProcess sets it) — Claude Code + // otherwise caps HTTP MCP tool calls at 60 s and would kill the call early. + internal static readonly TimeSpan QuestionWindow = TimeSpan.FromMinutes(3); + + internal const string TimeoutFallback = + "No response received within 3 minutes — proceed using your best judgment, " + + "note the assumption you made, and continue."; + + public TaskRunMcpService( + TaskRepository tasks, + TaskRunMcpContextAccessor ctx, + HubBroadcaster broadcaster, + PendingQuestionRegistry pending) { _tasks = tasks; _ctx = ctx; _broadcaster = broadcaster; + _pending = pending; } [McpServerTool, Description( @@ -47,4 +63,57 @@ public sealed class TaskRunMcpService await _broadcaster.TaskUpdated(callerId); return new SuggestedImprovementDto(child.Id); } + + [McpServerTool, Description( + "Ask the user a question and wait up to 3 minutes for their answer. Use this ONLY " + + "when you genuinely need a human decision to proceed correctly and a wrong guess " + + "would be costly or hard to undo (an irreversible action, contradictory " + + "requirements, or a real fork where both options have meaningful consequences). " + + "Do NOT use it for routine choices you can reasonably make yourself — for those, " + + "pick the most sensible option and continue. The returned string is the user's " + + "answer; if no one responds in time it tells you to proceed on your own judgment.")] + public Task AskUser(string question, CancellationToken cancellationToken) + { + var callerId = _ctx.Current.CallerTaskId; + return AwaitAnswerAsync( + _pending, callerId, question ?? string.Empty, + (questionId, q) => _broadcaster.TaskQuestionAsked(callerId, questionId, q), + questionId => _broadcaster.TaskQuestionResolved(callerId, questionId), + QuestionWindow, cancellationToken); + } + + // Registers the question, signals it, and blocks until the user answers, the window + // elapses (returns the fallback), or the run is cancelled (rethrows). Pure except for + // the two broadcast callbacks, so it is unit-testable without a real hub or DbContext. + internal static async Task AwaitAnswerAsync( + PendingQuestionRegistry pending, + string taskId, + string question, + Func onAsked, + Func onResolved, + TimeSpan timeout, + CancellationToken ct) + { + var (questionId, answer) = pending.Register(taskId, question); + await onAsked(questionId, question); + try + { + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + timeoutCts.CancelAfter(timeout); + try + { + return await answer.WaitAsync(timeoutCts.Token); + } + catch (OperationCanceledException) when (!ct.IsCancellationRequested) + { + // The window elapsed (not a run cancellation) — let the agent carry on. + return TimeoutFallback; + } + } + finally + { + pending.Remove(taskId, questionId); + await onResolved(questionId); + } + } } diff --git a/src/ClaudeDo.Worker/Runner/TaskRunner.cs b/src/ClaudeDo.Worker/Runner/TaskRunner.cs index 1b9721e..f86a20e 100644 --- a/src/ClaudeDo.Worker/Runner/TaskRunner.cs +++ b/src/ClaudeDo.Worker/Runner/TaskRunner.cs @@ -90,21 +90,23 @@ public sealed class TaskRunner var resolvedConfig = await ResolveConfigAsync(task, listConfig, null, ct); - // Improvement-eligible runs get a per-run MCP identity so the agent can file - // out-of-scope follow-ups via SuggestImprovement. Children and planning runs do not. - if (task.ParentTaskId is null && task.PlanningPhase == PlanningPhase.None) + // Every run gets a per-run MCP identity so the agent can ask the user a + // mid-run question via AskUser. Improvement-eligible (standalone top-level) + // runs additionally get SuggestImprovement for filing out-of-scope follow-ups. + mcpToken = TaskRunTokenRegistry.GenerateToken(); + _tokens.Register(mcpToken, task.Id); + Directory.CreateDirectory(_cfg.LogRoot); + mcpConfigPath = Path.Combine(_cfg.LogRoot, $"{task.Id}_mcp.json"); + await File.WriteAllTextAsync(mcpConfigPath, BuildRunMcpConfigJson(mcpToken), ct); + + var improvementEligible = task.ParentTaskId is null && task.PlanningPhase == PlanningPhase.None; + resolvedConfig = resolvedConfig with { - mcpToken = TaskRunTokenRegistry.GenerateToken(); - _tokens.Register(mcpToken, task.Id); - Directory.CreateDirectory(_cfg.LogRoot); - mcpConfigPath = Path.Combine(_cfg.LogRoot, $"{task.Id}_mcp.json"); - await File.WriteAllTextAsync(mcpConfigPath, BuildRunMcpConfigJson(mcpToken), ct); - resolvedConfig = resolvedConfig with - { - McpConfigPath = mcpConfigPath, - AllowedTools = "mcp__claudedo_run__SuggestImprovement", - }; - } + McpConfigPath = mcpConfigPath, + AllowedTools = improvementEligible + ? "mcp__claudedo_run__AskUser,mcp__claudedo_run__SuggestImprovement" + : "mcp__claudedo_run__AskUser", + }; var now = DateTime.UtcNow; // The queue picker claims Queued→Running atomically (incl. StartedAt) before diff --git a/tests/ClaudeDo.Worker.Tests/Hub/ClearMyDayHubTests.cs b/tests/ClaudeDo.Worker.Tests/Hub/ClearMyDayHubTests.cs index 31e1074..fbbd52a 100644 --- a/tests/ClaudeDo.Worker.Tests/Hub/ClearMyDayHubTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Hub/ClearMyDayHubTests.cs @@ -20,7 +20,8 @@ public sealed class ClearMyDayHubTests : IDisposable var hub = new WorkerHub( null!, null!, null!, null!, broadcaster, _db.CreateFactory(), null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, - null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore()); + null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore(), + new ClaudeDo.Worker.Runner.PendingQuestionRegistry()); hub.Clients = new FakeHubCallerClients(new RecordingClientProxy()); hub.Context = new FakeHubCallerContext(); return hub; diff --git a/tests/ClaudeDo.Worker.Tests/Hub/OnlineInboxHubTests.cs b/tests/ClaudeDo.Worker.Tests/Hub/OnlineInboxHubTests.cs index 0ab74d6..9d7c406 100644 --- a/tests/ClaudeDo.Worker.Tests/Hub/OnlineInboxHubTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Hub/OnlineInboxHubTests.cs @@ -31,7 +31,7 @@ public sealed class OnlineInboxHubTests : IDisposable var hub = new WorkerHub( null!, null!, null!, null!, broadcaster, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, - cfg, inboxCfg, store); + cfg, inboxCfg, store, new ClaudeDo.Worker.Runner.PendingQuestionRegistry()); hub.Clients = new FakeHubCallerClients(new RecordingClientProxy()); hub.Context = new FakeHubCallerContext(); return (hub, inboxCfg, store); diff --git a/tests/ClaudeDo.Worker.Tests/Hub/PlanningHubTests.cs b/tests/ClaudeDo.Worker.Tests/Hub/PlanningHubTests.cs index ee97f21..eca57bd 100644 --- a/tests/ClaudeDo.Worker.Tests/Hub/PlanningHubTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Hub/PlanningHubTests.cs @@ -56,7 +56,8 @@ public sealed class PlanningHubTests : IDisposable var hub = new WorkerHub( null!, null!, null!, null!, null!, null!, null!, null!, null!, _planning, _launcher, null!, null!, null!, null!, null!, null!, null!, null!, - null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore()); + null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore(), + new ClaudeDo.Worker.Runner.PendingQuestionRegistry()); hub.Clients = new FakeHubCallerClients(_proxy); hub.Context = new FakeHubCallerContext(); return hub; diff --git a/tests/ClaudeDo.Worker.Tests/Hub/WorktreeStateHubTests.cs b/tests/ClaudeDo.Worker.Tests/Hub/WorktreeStateHubTests.cs index 5f95f1b..a46be04 100644 --- a/tests/ClaudeDo.Worker.Tests/Hub/WorktreeStateHubTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Hub/WorktreeStateHubTests.cs @@ -20,7 +20,8 @@ public sealed class WorktreeStateHubTests : IDisposable var hub = new WorkerHub( null!, null!, null!, null!, broadcaster, _db.CreateFactory(), null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, null!, - null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore()); + null!, new ClaudeDo.Worker.Online.OnlineInboxConfig(), new ClaudeDo.Worker.Online.OnlineTokenStore(), + new ClaudeDo.Worker.Runner.PendingQuestionRegistry()); hub.Clients = new FakeHubCallerClients(new RecordingClientProxy()); hub.Context = new FakeHubCallerContext(); return hub; diff --git a/tests/ClaudeDo.Worker.Tests/Runner/PendingQuestionTests.cs b/tests/ClaudeDo.Worker.Tests/Runner/PendingQuestionTests.cs new file mode 100644 index 0000000..22e98bb --- /dev/null +++ b/tests/ClaudeDo.Worker.Tests/Runner/PendingQuestionTests.cs @@ -0,0 +1,126 @@ +using ClaudeDo.Worker.Runner; +using Xunit; + +namespace ClaudeDo.Worker.Tests.Runner; + +public class PendingQuestionRegistryTests +{ + [Fact] + public async Task Register_ThenAnswer_CompletesTheWait() + { + var registry = new PendingQuestionRegistry(); + var (questionId, answer) = registry.Register("t1", "which?"); + + Assert.False(answer.IsCompleted); + Assert.Equal("which?", registry.Get("t1")?.Question); + + Assert.True(registry.TryAnswer("t1", questionId, "this one")); + Assert.Equal("this one", await answer); + Assert.Null(registry.Get("t1")); // cleared after answering + } + + [Fact] + public void TryAnswer_WrongQuestionId_DoesNothing() + { + var registry = new PendingQuestionRegistry(); + var (_, answer) = registry.Register("t1", "q?"); + + Assert.False(registry.TryAnswer("t1", "stale-id", "x")); + Assert.False(answer.IsCompleted); + Assert.NotNull(registry.Get("t1")); + } + + [Fact] + public void TryAnswer_UnknownTask_ReturnsFalse() + { + var registry = new PendingQuestionRegistry(); + Assert.False(registry.TryAnswer("ghost", "q", "x")); + } + + [Fact] + public void SecondRegister_OverwritesStaleEntry() + { + var registry = new PendingQuestionRegistry(); + var (firstId, _) = registry.Register("t1", "first"); + var (secondId, _) = registry.Register("t1", "second"); + + Assert.NotEqual(firstId, secondId); + Assert.Equal("second", registry.Get("t1")?.Question); + Assert.False(registry.TryAnswer("t1", firstId, "x")); // old id no longer valid + Assert.True(registry.TryAnswer("t1", secondId, "ok")); + } + + [Fact] + public void Remove_DropsPendingWithoutAnswering() + { + var registry = new PendingQuestionRegistry(); + var (questionId, answer) = registry.Register("t1", "q?"); + + registry.Remove("t1", questionId); + + Assert.Null(registry.Get("t1")); + Assert.False(answer.IsCompleted); + } +} + +public class AskUserWaitTests +{ + [Fact] + public async Task AwaitAnswer_ReturnsUserAnswer_WhenAnsweredInTime() + { + var registry = new PendingQuestionRegistry(); + string? askedQuestionId = null; + var asked = new TaskCompletionSource(); + var resolved = 0; + + var wait = TaskRunMcpService.AwaitAnswerAsync( + registry, "t1", "DPAPI or plaintext?", + onAsked: (qid, _) => { askedQuestionId = qid; asked.TrySetResult(); return Task.CompletedTask; }, + onResolved: _ => { resolved++; return Task.CompletedTask; }, + timeout: TimeSpan.FromSeconds(5), + ct: CancellationToken.None); + + await asked.Task; // registration + onAsked have run + Assert.True(registry.TryAnswer("t1", askedQuestionId!, "DPAPI please")); + + Assert.Equal("DPAPI please", await wait); + Assert.Equal(1, resolved); + Assert.Null(registry.Get("t1")); + } + + [Fact] + public async Task AwaitAnswer_ReturnsFallback_OnTimeout() + { + var registry = new PendingQuestionRegistry(); + var resolved = 0; + + var result = await TaskRunMcpService.AwaitAnswerAsync( + registry, "t2", "q?", + onAsked: (_, _) => Task.CompletedTask, + onResolved: _ => { resolved++; return Task.CompletedTask; }, + timeout: TimeSpan.FromMilliseconds(40), + ct: CancellationToken.None); + + Assert.Equal(TaskRunMcpService.TimeoutFallback, result); + Assert.Equal(1, resolved); + Assert.Null(registry.Get("t2")); // cleaned up after timeout + } + + [Fact] + public async Task AwaitAnswer_Rethrows_WhenRunCancelled() + { + var registry = new PendingQuestionRegistry(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(() => + TaskRunMcpService.AwaitAnswerAsync( + registry, "t3", "q?", + onAsked: (_, _) => Task.CompletedTask, + onResolved: _ => Task.CompletedTask, + timeout: TimeSpan.FromMinutes(1), + ct: cts.Token)); + + Assert.Null(registry.Get("t3")); // cleanup still ran + } +} diff --git a/tests/ClaudeDo.Worker.Tests/SuggestImprovementTests.cs b/tests/ClaudeDo.Worker.Tests/SuggestImprovementTests.cs index 3a31439..c8e6854 100644 --- a/tests/ClaudeDo.Worker.Tests/SuggestImprovementTests.cs +++ b/tests/ClaudeDo.Worker.Tests/SuggestImprovementTests.cs @@ -37,7 +37,7 @@ public sealed class SuggestImprovementTests : IDisposable await SeedCallerAsync("caller", parentId: null); using var ctx = _db.CreateContext(); var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"), - new HubBroadcaster(new CapturingHubContext())); + new HubBroadcaster(new CapturingHubContext()), new PendingQuestionRegistry()); var dto = await svc.SuggestImprovement("Refactor X", "details", model: null, default); var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId); Assert.Equal("caller", child!.ParentTaskId); @@ -53,7 +53,7 @@ public sealed class SuggestImprovementTests : IDisposable await SeedCallerAsync("caller", parentId: null); using var ctx = _db.CreateContext(); var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"), - new HubBroadcaster(new CapturingHubContext())); + new HubBroadcaster(new CapturingHubContext()), new PendingQuestionRegistry()); var dto = await svc.SuggestImprovement("Refactor X", "details", model: "HAIKU", default); var child = await new TaskRepository(ctx).GetByIdAsync(dto.ChildTaskId); Assert.Equal("haiku", child!.Model); @@ -65,7 +65,7 @@ public sealed class SuggestImprovementTests : IDisposable await SeedCallerAsync("caller", parentId: null); using var ctx = _db.CreateContext(); var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("caller"), - new HubBroadcaster(new CapturingHubContext())); + new HubBroadcaster(new CapturingHubContext()), new PendingQuestionRegistry()); await Assert.ThrowsAsync( () => svc.SuggestImprovement("x", "y", model: "gpt4", default)); } @@ -77,7 +77,7 @@ public sealed class SuggestImprovementTests : IDisposable await SeedCallerAsync("child", parentId: "parent"); using var ctx = _db.CreateContext(); var svc = new TaskRunMcpService(new TaskRepository(ctx), AccessorFor("child"), - new HubBroadcaster(new CapturingHubContext())); + new HubBroadcaster(new CapturingHubContext()), new PendingQuestionRegistry()); await Assert.ThrowsAsync( () => svc.SuggestImprovement("nested", "x", model: null, default)); }