// Source: ClaudeDo/src/ClaudeDo.Worker/Runner/TaskRunner.cs (C#, 404 lines, 16 KiB)

using ClaudeDo.Data;
using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Config;
using ClaudeDo.Worker.Hub;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Runner;
public sealed class TaskRunner
{
// Launches the Claude process for a single run.
private readonly IClaudeProcess _claude;
// Creates a short-lived DbContext for each unit of database work.
private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
// Pushes task, run, worktree and worker-log notifications to listeners.
private readonly HubBroadcaster _broadcaster;
// Creates worktrees and commits changes made inside them.
private readonly WorktreeManager _wtManager;
// Turns a resolved run configuration into process arguments.
private readonly ClaudeArgsBuilder _argsBuilder;
// Supplies SandboxRoot and LogRoot.
private readonly WorkerConfig _cfg;
private readonly ILogger<TaskRunner> _logger;

/// <summary>
/// Stores the injected collaborators. No work is started by the constructor.
/// </summary>
public TaskRunner(
    IClaudeProcess claude,
    IDbContextFactory<ClaudeDoDbContext> dbFactory,
    HubBroadcaster broadcaster,
    WorktreeManager wtManager,
    ClaudeArgsBuilder argsBuilder,
    WorkerConfig cfg,
    ILogger<TaskRunner> logger)
{
    (_claude, _dbFactory, _broadcaster, _wtManager, _argsBuilder, _cfg, _logger) =
        (claude, dbFactory, broadcaster, wtManager, argsBuilder, cfg, logger);
}
/// <summary>
/// Runs a task from scratch in the given slot.
/// Loads the task's list, list config and subtasks; chooses a run directory
/// (a freshly created worktree when the list has a working directory,
/// otherwise a per-task folder under the sandbox root); marks the task as
/// running; sends the initial prompt; and, if the first run fails but
/// reported a session id, retries once by resuming that session.
/// The final outcome is persisted as done or failed.
/// </summary>
/// <param name="task">The task to execute.</param>
/// <param name="slot">Worker slot identifier forwarded to broadcasts.</param>
/// <param name="ct">Cancels the run; a cancelled task is recorded as failed.</param>
/// <remarks>
/// Cancellation and other exceptions are caught here, logged, and turned into
/// a failed task status rather than being propagated to the caller.
/// </remarks>
public async Task RunAsync(TaskEntity task, string slot, CancellationToken ct)
{
    try
    {
        ListEntity? list;
        ListConfigEntity? listConfig;
        List<SubtaskEntity> subtasks;
        using (var db = _dbFactory.CreateDbContext())
        {
            var lists = new ListRepository(db);
            list = await lists.GetByIdAsync(task.ListId, ct);
            if (list is null)
            {
                await MarkFailed(task.Id, task.Title, slot, "List not found.");
                return;
            }

            listConfig = await lists.GetConfigAsync(task.ListId, ct);
            subtasks = await new SubtaskRepository(db).GetByTaskIdAsync(task.Id, ct);
        }

        // Pick where Claude runs: sandbox folder, or a dedicated worktree.
        WorktreeContext? wtCtx = null;
        string runDir;
        if (list.WorkingDir is null)
        {
            runDir = Path.Combine(_cfg.SandboxRoot, task.Id);
            Directory.CreateDirectory(runDir);
        }
        else
        {
            try
            {
                wtCtx = await _wtManager.CreateAsync(task, list, ct);
                await _broadcaster.WorkerLog($"Created worktree for \"{task.Title}\"", WorkerLogLevel.Info, DateTime.UtcNow);
                runDir = wtCtx.WorktreePath;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to create worktree for task {TaskId}", task.Id);
                await MarkFailed(task.Id, task.Title, slot, $"Worktree creation failed: {ex.Message}");
                return;
            }
        }

        var resolvedConfig = await ResolveConfigAsync(task, listConfig, null, ct);

        var startedAt = DateTime.UtcNow;
        using (var db = _dbFactory.CreateDbContext())
        {
            await new TaskRepository(db).MarkRunningAsync(task.Id, startedAt, ct);
        }
        await _broadcaster.TaskStarted(slot, task.Id, startedAt);

        var prompt = BuildInitialPrompt(task, subtasks);

        // First attempt.
        var outcome = await RunOnceAsync(task.Id, task.Title, slot, runDir, resolvedConfig, 1, false, prompt, ct);

        // One automatic retry, only possible when there is a session to resume.
        if (!outcome.IsSuccess && outcome.SessionId is not null)
        {
            _logger.LogInformation("Auto-retrying task {TaskId} with session {SessionId}", task.Id, outcome.SessionId);
            var retryConfig = resolvedConfig with { ResumeSessionId = outcome.SessionId };
            var retryPrompt = $"The previous attempt failed with:\n\n{outcome.ErrorMarkdown}\n\nTry again and fix the issues.";
            outcome = await RunOnceAsync(task.Id, task.Title, slot, runDir, retryConfig, 2, true, retryPrompt, ct);
        }

        // Whichever run came last decides the task's final state.
        if (outcome.IsSuccess)
        {
            await HandleSuccess(task, list, slot, wtCtx, outcome, ct);
        }
        else
        {
            await HandleFailure(task.Id, task.Title, slot, outcome);
        }

        await _broadcaster.TaskUpdated(task.Id);
    }
    catch (OperationCanceledException)
    {
        _logger.LogInformation("Task {TaskId} was cancelled", task.Id);
        await MarkFailed(task.Id, task.Title, slot, "Task cancelled.");
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Unhandled exception running task {TaskId}", task.Id);
        await MarkFailed(task.Id, task.Title, slot, $"Unhandled error: {ex.Message}");
    }
}

/// <summary>
/// Composes the first-run prompt: the task title, then the trimmed description
/// when it is not blank, then a "Sub-Tasks" markdown checklist when the task
/// has subtasks.
/// </summary>
private static string BuildInitialPrompt(TaskEntity task, List<SubtaskEntity> subtasks)
{
    var text = new System.Text.StringBuilder(task.Title);

    if (!string.IsNullOrWhiteSpace(task.Description))
    {
        text.Append("\n\n");
        text.Append(task.Description.Trim());
    }

    if (subtasks.Count == 0)
    {
        return text.ToString();
    }

    text.Append("\n\n## Sub-Tasks\n");
    foreach (var subtask in subtasks)
    {
        var checkbox = subtask.Completed ? "- [x] " : "- [ ] ";
        text.Append(checkbox).Append(subtask.Title).Append('\n');
    }

    return text.ToString();
}
/// <summary>
/// Continues a task by resuming the session of its most recent run with a
/// follow-up prompt. The run happens in the task's existing worktree when one
/// is recorded, otherwise in the per-task sandbox folder.
/// </summary>
/// <param name="taskId">Id of the task to continue.</param>
/// <param name="followUpPrompt">Prompt sent to the resumed session.</param>
/// <param name="slot">Worker slot identifier forwarded to broadcasts.</param>
/// <param name="ct">Cancels the run.</param>
/// <exception cref="KeyNotFoundException">No task exists with <paramref name="taskId"/>.</exception>
/// <exception cref="InvalidOperationException">
/// The task has no previous run, the previous run has no session id, or the
/// task's list cannot be found.
/// </exception>
/// <remarks>
/// Once the task has been marked as running, any exception (including
/// cancellation) is recorded as a failed task before being rethrown, so the
/// task is never left in the running state.
/// </remarks>
public async Task ContinueAsync(string taskId, string followUpPrompt, string slot, CancellationToken ct)
{
    TaskEntity task;
    TaskRunEntity lastRun;
    ListEntity list;
    ListConfigEntity? listConfig;
    WorktreeEntity? worktree;
    using (var context = _dbFactory.CreateDbContext())
    {
        var taskRepo = new TaskRepository(context);
        task = await taskRepo.GetByIdAsync(taskId, ct)
            ?? throw new KeyNotFoundException($"Task '{taskId}' not found.");
        var runRepo = new TaskRunRepository(context);
        lastRun = await runRepo.GetLatestByTaskIdAsync(taskId, ct)
            ?? throw new InvalidOperationException("No previous run to continue.");
        if (lastRun.SessionId is null)
            throw new InvalidOperationException("Previous run has no session ID — cannot resume.");
        var listRepo = new ListRepository(context);
        list = await listRepo.GetByIdAsync(task.ListId, ct)
            ?? throw new InvalidOperationException("List not found.");
        listConfig = await listRepo.GetConfigAsync(task.ListId, ct);
        var wtRepo = new WorktreeRepository(context);
        worktree = await wtRepo.GetByTaskIdAsync(taskId, ct);
    }

    var resolvedConfig = await ResolveConfigAsync(task, listConfig, lastRun.SessionId, ct);

    // Determine run directory from existing worktree or sandbox.
    string runDir;
    WorktreeContext? wtCtx = null;
    if (worktree is not null)
    {
        runDir = worktree.Path;
        wtCtx = new WorktreeContext(worktree.Path, worktree.BranchName, worktree.BaseCommit);
    }
    else
    {
        runDir = Path.Combine(_cfg.SandboxRoot, taskId);
        // RunAsync creates this folder on the first run; make sure it still
        // exists so the process has a valid working directory. This is a
        // no-op when the folder is already there.
        Directory.CreateDirectory(runDir);
    }

    var now = DateTime.UtcNow;
    using (var context = _dbFactory.CreateDbContext())
    {
        var taskRepo = new TaskRepository(context);
        await taskRepo.MarkRunningAsync(taskId, now, ct);
    }

    // From here on the task is persisted as running, so every exit path must
    // write a terminal status. MarkFailed never throws, and the original
    // exception is rethrown so callers observe the same failures as before.
    try
    {
        await _broadcaster.TaskStarted(slot, taskId, now);
        var nextRunNumber = lastRun.RunNumber + 1;
        var result = await RunOnceAsync(taskId, task.Title, slot, runDir, resolvedConfig, nextRunNumber, false, followUpPrompt, ct);
        if (result.IsSuccess)
        {
            await HandleSuccess(task, list, slot, wtCtx, result, ct);
        }
        else
        {
            await HandleFailure(taskId, task.Title, slot, result);
        }
    }
    catch (OperationCanceledException)
    {
        _logger.LogInformation("Task {TaskId} was cancelled", taskId);
        await MarkFailed(taskId, task.Title, slot, "Task cancelled.");
        throw;
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Unhandled exception continuing task {TaskId}", taskId);
        await MarkFailed(taskId, task.Title, slot, $"Unhandled error: {ex.Message}");
        throw;
    }

    // Kept outside the guarded region: a failed notification must not flip an
    // already-finished task to failed.
    await _broadcaster.TaskUpdated(taskId);
}
/// <summary>
/// Performs a single Claude invocation for a task: inserts a run row,
/// announces the run, streams each output line to the run's NDJSON log file
/// and to listeners, then stores the outcome on the run row and records the
/// log path on the task.
/// </summary>
/// <param name="taskId">Id of the task being run.</param>
/// <param name="taskTitle">Task title, used in the worker log message.</param>
/// <param name="slot">Worker slot identifier (not used by this method).</param>
/// <param name="runDir">Working directory passed to the Claude process.</param>
/// <param name="config">Resolved configuration used to build the process arguments.</param>
/// <param name="runNumber">Sequence number of this run; also part of the log file name.</param>
/// <param name="isRetry">Whether this run is an automatic retry.</param>
/// <param name="prompt">Prompt passed to the Claude process.</param>
/// <param name="ct">Cancels the Claude process and log writes.</param>
/// <returns>The result reported by the Claude process.</returns>
/// <remarks>
/// If anything fails after the run row has been inserted, whether through
/// cancellation or any other exception, the row is completed on a
/// best-effort basis with an error and exit code -1, and the exception is
/// rethrown.
/// </remarks>
private async Task<RunResult> RunOnceAsync(
    string taskId, string taskTitle, string slot, string runDir, ClaudeRunConfig config,
    int runNumber, bool isRetry, string prompt, CancellationToken ct)
{
    var runId = Guid.NewGuid().ToString();
    var logPath = Path.Combine(_cfg.LogRoot, $"{taskId}_run{runNumber}.ndjson");
    var run = new TaskRunEntity
    {
        Id = runId,
        TaskId = taskId,
        RunNumber = runNumber,
        IsRetry = isRetry,
        Prompt = prompt,
        LogPath = logPath,
        StartedAt = DateTime.UtcNow,
    };
    using (var context = _dbFactory.CreateDbContext())
    {
        var runRepo = new TaskRunRepository(context);
        await runRepo.AddAsync(run, ct);
    }

    // The run row now exists. Everything up to the point where Claude returns
    // is guarded so the row is never left without a finish time.
    RunResult result;
    try
    {
        await _broadcaster.RunCreated(taskId, runNumber, isRetry);
        var arguments = _argsBuilder.Build(config);
        await using var logWriter = new LogWriter(logPath);
        await _broadcaster.WorkerLog($"Started Claude for \"{taskTitle}\"", WorkerLogLevel.Info, DateTime.UtcNow);
        result = await _claude.RunAsync(
            arguments,
            prompt,
            runDir,
            async line =>
            {
                await logWriter.WriteLineAsync(line, ct);
                await _broadcaster.TaskMessage(taskId, "[stdout] " + line);
            },
            ct);
    }
    catch (OperationCanceledException)
    {
        // Ensure the run row is completed so ContinueAsync / inspection
        // isn't left staring at a null session_id / finished_at.
        await FinalizeAbortedRunAsync(run, taskId, logPath, "Cancelled.");
        throw;
    }
    catch (Exception ex)
    {
        // Same guarantee for non-cancellation failures.
        await FinalizeAbortedRunAsync(run, taskId, logPath, $"Run aborted: {ex.Message}");
        throw;
    }

    run.SessionId = result.SessionId;
    run.ResultMarkdown = result.ResultMarkdown;
    run.StructuredOutputJson = result.StructuredOutputJson;
    run.ErrorMarkdown = result.ErrorMarkdown;
    run.ExitCode = result.ExitCode;
    run.TurnCount = result.TurnCount;
    run.TokensIn = result.TokensIn;
    run.TokensOut = result.TokensOut;
    run.FinishedAt = DateTime.UtcNow;

    // Use CancellationToken.None: this is a terminal write that must always
    // complete, even if the caller's token is already cancelled.
    using (var context = _dbFactory.CreateDbContext())
    {
        var runRepo = new TaskRunRepository(context);
        await runRepo.UpdateAsync(run, CancellationToken.None);
        var taskRepo = new TaskRepository(context);
        await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None);
    }

    return result;
}

/// <summary>
/// Best-effort completion of a run row whose Claude invocation did not return
/// a result. Stamps the error, exit code -1 and finish time on the row and
/// records the log path on the task. Persistence failures are logged and
/// swallowed so the caller can rethrow the original exception.
/// </summary>
private async Task FinalizeAbortedRunAsync(TaskRunEntity run, string taskId, string logPath, string error)
{
    run.ErrorMarkdown = error;
    run.ExitCode = -1;
    run.FinishedAt = DateTime.UtcNow;
    try
    {
        using var context = _dbFactory.CreateDbContext();
        var runRepo = new TaskRunRepository(context);
        await runRepo.UpdateAsync(run, CancellationToken.None);
        var taskRepo = new TaskRepository(context);
        await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None);
    }
    catch (Exception updateEx)
    {
        _logger.LogError(updateEx, "Failed to finalize aborted run {RunId} for task {TaskId}", run.Id, taskId);
    }
}
/// <summary>
/// Finishes a successful run: commits worktree changes when the run used a
/// worktree and something changed, marks the task as done with the run's
/// result markdown, and notifies listeners.
/// </summary>
/// <param name="task">The task that succeeded.</param>
/// <param name="list">The list the task belongs to; forwarded to the commit step.</param>
/// <param name="slot">Worker slot identifier forwarded to broadcasts.</param>
/// <param name="wtCtx">Worktree the run used, or <c>null</c> for a sandbox run.</param>
/// <param name="result">Result of the successful run.</param>
/// <param name="ct">Applies to the commit step only.</param>
private async Task HandleSuccess(TaskEntity task, ListEntity list, string slot, WorktreeContext? wtCtx, RunResult result, CancellationToken ct)
{
    // The short-circuit means the commit is attempted only for worktree runs.
    if (wtCtx is not null && await _wtManager.CommitIfChangedAsync(wtCtx, task, list, ct))
    {
        await _broadcaster.WorkerLog($"Committed changes in \"{task.Title}\"", WorkerLogLevel.Info, DateTime.UtcNow);
        await _broadcaster.WorktreeUpdated(task.Id);
    }

    // The status write ignores the caller's token: a cancel arriving after
    // Claude already succeeded must not leave the task stuck as 'running'.
    var completedAt = DateTime.UtcNow;
    using (var db = _dbFactory.CreateDbContext())
    {
        await new TaskRepository(db).MarkDoneAsync(task.Id, completedAt, result.ResultMarkdown, CancellationToken.None);
    }

    await _broadcaster.WorkerLog($"Finished \"{task.Title}\" (done)", WorkerLogLevel.Success, DateTime.UtcNow);
    await _broadcaster.TaskFinished(slot, task.Id, "done", completedAt);

    _logger.LogInformation(
        "Task {TaskId} completed (turns={Turns}, tokens_in={In}, tokens_out={Out})",
        task.Id,
        result.TurnCount,
        result.TokensIn,
        result.TokensOut);
}
/// <summary>
/// Finishes a failed run: marks the task as failed with the run's error
/// markdown, notifies listeners, and logs a warning.
/// </summary>
/// <param name="taskId">Id of the failed task.</param>
/// <param name="taskTitle">Task title, used in the worker log message.</param>
/// <param name="slot">Worker slot identifier forwarded to broadcasts.</param>
/// <param name="result">Result of the failed run.</param>
/// <remarks>
/// Deliberately takes no cancellation token: the failed status is a terminal
/// write and must be persisted regardless of cancellation.
/// </remarks>
private async Task HandleFailure(string taskId, string taskTitle, string slot, RunResult result)
{
    var failedAt = DateTime.UtcNow;

    using var db = _dbFactory.CreateDbContext();
    var tasks = new TaskRepository(db);
    await tasks.MarkFailedAsync(taskId, failedAt, result.ErrorMarkdown, CancellationToken.None);

    await _broadcaster.WorkerLog($"Finished \"{taskTitle}\" (failed)", WorkerLogLevel.Error, DateTime.UtcNow);
    await _broadcaster.TaskFinished(slot, taskId, "failed", failedAt);

    _logger.LogWarning(
        "Task {TaskId} failed (turns={Turns}): {Error}",
        taskId,
        result.TurnCount,
        result.ErrorMarkdown);
}
/// <summary>
/// Records a task as failed with the given error text and notifies listeners
/// (worker log, task finished, task updated). Never throws: any exception
/// raised while doing so is logged and swallowed.
/// </summary>
/// <param name="taskId">Id of the task to mark as failed.</param>
/// <param name="taskTitle">Task title, used in the worker log message.</param>
/// <param name="slot">Worker slot identifier forwarded to broadcasts.</param>
/// <param name="error">Error text stored on the task.</param>
private async Task MarkFailed(string taskId, string taskTitle, string slot, string error)
{
    try
    {
        var failedAt = DateTime.UtcNow;

        // Terminal status write, so it is deliberately not cancellable.
        using var db = _dbFactory.CreateDbContext();
        var tasks = new TaskRepository(db);
        await tasks.MarkFailedAsync(taskId, failedAt, error, CancellationToken.None);

        await _broadcaster.WorkerLog($"Finished \"{taskTitle}\" (failed)", WorkerLogLevel.Error, DateTime.UtcNow);
        await _broadcaster.TaskFinished(slot, taskId, "failed", failedAt);
        await _broadcaster.TaskUpdated(taskId);
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Failed to mark task {TaskId} as failed", taskId);
    }
}
/// <summary>
/// Builds the effective Claude run configuration by layering task settings
/// over list settings over the application-wide settings.
/// Model falls back task → list → global default; agent path falls back
/// task → list; system prompt is the merge of all three levels (or
/// <c>null</c> when the merge is blank); max turns and permission mode always
/// come from the global settings.
/// </summary>
/// <param name="task">Task whose overrides take highest priority.</param>
/// <param name="listConfig">Optional list-level configuration.</param>
/// <param name="resumeSessionId">Session to resume, or <c>null</c> for a fresh session.</param>
/// <param name="ct">Cancels loading of the application settings.</param>
private async Task<ClaudeRunConfig> ResolveConfigAsync(
    TaskEntity task, ListConfigEntity? listConfig, string? resumeSessionId, CancellationToken ct)
{
    AppSettingsEntity appSettings;
    using (var db = _dbFactory.CreateDbContext())
    {
        appSettings = await new AppSettingsRepository(db).GetAsync(ct);
    }

    var mergedPrompt = MergeInstructions(
        appSettings.DefaultClaudeInstructions,
        listConfig?.SystemPrompt,
        task.SystemPrompt);

    var model = task.Model ?? listConfig?.Model ?? appSettings.DefaultModel;
    var agentPath = task.AgentPath ?? listConfig?.AgentPath;

    return new ClaudeRunConfig(
        Model: model,
        SystemPrompt: !string.IsNullOrWhiteSpace(mergedPrompt) ? mergedPrompt : null,
        AgentPath: agentPath,
        ResumeSessionId: resumeSessionId,
        MaxTurns: appSettings.DefaultMaxTurns,
        PermissionMode: appSettings.DefaultPermissionMode);
}
/// <summary>
/// Concatenates the global, list and task instruction texts, in that order.
/// Null or whitespace-only entries are skipped, each remaining entry is
/// trimmed, and entries are separated by a blank line.
/// </summary>
/// <param name="global">Application-wide instructions, if any.</param>
/// <param name="list">List-level instructions, if any.</param>
/// <param name="task">Task-level instructions, if any.</param>
/// <returns>The merged text, or an empty string when every entry is blank.</returns>
public static string MergeInstructions(string? global, string? list, string? task)
{
    var merged = new System.Text.StringBuilder();

    foreach (var layer in new[] { global, list, task })
    {
        if (string.IsNullOrWhiteSpace(layer))
        {
            continue;
        }

        // A separator is only needed once something has already been written.
        if (merged.Length > 0)
        {
            merged.Append("\n\n");
        }

        merged.Append(layer.Trim());
    }

    return merged.ToString();
}
}