diff --git a/src/ClaudeDo.Worker/Planning/PlanningTokenAuth.cs b/src/ClaudeDo.Worker/Planning/PlanningTokenAuth.cs index ec8a9bd..dd6e153 100644 --- a/src/ClaudeDo.Worker/Planning/PlanningTokenAuth.cs +++ b/src/ClaudeDo.Worker/Planning/PlanningTokenAuth.cs @@ -1,4 +1,5 @@ using ClaudeDo.Data.Repositories; +using ClaudeDo.Worker.Runner; using Microsoft.AspNetCore.Http; namespace ClaudeDo.Worker.Planning; @@ -9,7 +10,7 @@ public sealed class PlanningTokenAuthMiddleware public PlanningTokenAuthMiddleware(RequestDelegate next) => _next = next; - public async Task InvokeAsync(HttpContext ctx, TaskRepository tasks) + public async Task InvokeAsync(HttpContext ctx, TaskRepository tasks, TaskRunTokenRegistry runTokens) { if (!ctx.Request.Path.StartsWithSegments("/mcp")) { @@ -26,15 +27,23 @@ public sealed class PlanningTokenAuthMiddleware } var token = auth.Substring("Bearer ".Length).Trim(); + var parent = await tasks.FindByPlanningTokenAsync(token, ctx.RequestAborted); - if (parent is null || parent.PlanningPhase != ClaudeDo.Data.Models.PlanningPhase.Active) + if (parent is not null && parent.PlanningPhase == ClaudeDo.Data.Models.PlanningPhase.Active) { - ctx.Response.StatusCode = 401; - await ctx.Response.WriteAsync("Invalid or expired planning token"); + ctx.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id }; + await _next(ctx); return; } - ctx.Items["PlanningContext"] = new PlanningMcpContext { ParentTaskId = parent.Id }; - await _next(ctx); + if (runTokens.TryResolve(token, out var callerTaskId)) + { + ctx.Items["TaskRunContext"] = new TaskRunMcpContext { CallerTaskId = callerTaskId }; + await _next(ctx); + return; + } + + ctx.Response.StatusCode = 401; + await ctx.Response.WriteAsync("Invalid or expired token"); } } diff --git a/src/ClaudeDo.Worker/Program.cs b/src/ClaudeDo.Worker/Program.cs index 68cf03f..3f06a19 100644 --- a/src/ClaudeDo.Worker/Program.cs +++ b/src/ClaudeDo.Worker/Program.cs @@ -56,6 +56,7 @@ builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); @@ -131,6 +132,8 @@ builder.Services.AddSingleton(sp => new WindowsTerminalPlanningLauncher("wt.exe", cfg.ClaudeBin)); builder.Services.AddHttpContextAccessor(); builder.Services.AddScoped(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); builder.Services.AddScoped(sp => sp.GetRequiredService>().CreateDbContext()); builder.Services.AddScoped(); @@ -138,7 +141,8 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddMcpServer() .WithHttpTransport() - .WithTools(); + .WithTools() + .WithTools(); // Loopback-only bind. Firewall is irrelevant for 127.0.0.1. builder.WebHost.UseUrls($"http://127.0.0.1:{cfg.SignalRPort}"); diff --git a/tests/ClaudeDo.Worker.Tests/Hub/TaskRunTokenAuthTests.cs b/tests/ClaudeDo.Worker.Tests/Hub/TaskRunTokenAuthTests.cs new file mode 100644 index 0000000..04d1c48 --- /dev/null +++ b/tests/ClaudeDo.Worker.Tests/Hub/TaskRunTokenAuthTests.cs @@ -0,0 +1,44 @@ +using ClaudeDo.Data.Repositories; +using ClaudeDo.Worker.Planning; +using ClaudeDo.Worker.Runner; +using ClaudeDo.Worker.Tests.Infrastructure; +using Microsoft.AspNetCore.Http; +using Xunit; + +namespace ClaudeDo.Worker.Tests.Hub; + +public sealed class TaskRunTokenAuthTests : IDisposable +{ + private readonly DbFixture _db = new(); + public void Dispose() => _db.Dispose(); + + [Fact] + public async Task Valid_taskRun_token_populates_TaskRunContext_and_calls_next() + { + var reg = new TaskRunTokenRegistry(); + reg.Register("run-token", "task-1"); + bool nextCalled = false; + var mw = new PlanningTokenAuthMiddleware(_ => { nextCalled = true; return Task.CompletedTask; }); + var ctx = new DefaultHttpContext(); + ctx.Request.Path = "/mcp"; + ctx.Request.Headers["Authorization"] = "Bearer run-token"; + using var db = _db.CreateContext(); + await mw.InvokeAsync(ctx, new TaskRepository(db), reg); + Assert.True(nextCalled); + var resolved = ctx.Items["TaskRunContext"] as TaskRunMcpContext; + Assert.NotNull(resolved); + Assert.Equal("task-1", resolved!.CallerTaskId); + } + + [Fact] + public async Task Unknown_token_returns_401() + { + var mw = new PlanningTokenAuthMiddleware(_ => Task.CompletedTask); + var ctx = new DefaultHttpContext(); + ctx.Request.Path = "/mcp"; + ctx.Request.Headers["Authorization"] = "Bearer nope"; + using var db = _db.CreateContext(); + await mw.InvokeAsync(ctx, new TaskRepository(db), new TaskRunTokenRegistry()); + Assert.Equal(401, ctx.Response.StatusCode); + } +}