Skip to content

Commit

Permalink
fix: replace tool types with openai tool choice
Browse files Browse the repository at this point in the history
  • Loading branch information
sshivaditya committed Jan 10, 2025
1 parent b878307 commit 239e201
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 68 deletions.
70 changes: 39 additions & 31 deletions src/adapters/openai/helpers/completions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Available Tools:
### ReadFile Tool ###
- Purpose: Read file contents
- Method: execute(filename: string)
- Method: execute(args: { filename: string })
- Returns: ToolResult<FileReadResult> containing:
- success: boolean
- data: { content: string, path: string }
Expand All @@ -80,7 +80,7 @@ Available Tools:
### WriteFile Tool ###
- Purpose: Update file contents using diff blocks
- Method: execute(path: string, diff: string)
- Method: execute(args: { filename: string, content: string })
- Requires absolute file paths (must start with '/')
- Diff format:
Expand All @@ -98,7 +98,7 @@ Available Tools:
### ExploreDir Tool ###
- Purpose: Directory operations
- Method: execute(command: 'tree', args?: any)
- Method: execute(args: { command: 'tree' | 'change-dir' | 'clone' | 'kill', dir?: string, repo?: string, owner?: string, issueNumber?: number })
- Returns: ToolResult<DirectoryExploreResult> containing:
- success: boolean
- data: { currentPath: string, tree?: string }
Expand All @@ -107,7 +107,7 @@ Available Tools:
### SearchFiles Tool ###
- Purpose: Search files using regex patterns
- Method: execute(pattern: string, options?: { filePattern?: string, caseSensitive?: boolean, contextLines?: number })
- Method: execute(args: { pattern: string, filePattern?: string, caseSensitive?: boolean, contextLines?: number })
- Returns: ToolResult<SearchResult> containing:
- success: boolean
- data: {
Expand Down Expand Up @@ -142,15 +142,7 @@ type ToolName = keyof ToolResultMap;

interface ToolRequest {
tool: ToolName;
args: {
filename?: string;
content?: string;
command?: "tree";
pattern?: string;
filePattern?: string;
caseSensitive?: boolean;
contextLines?: number;
};
args: Record<string, unknown>;
}

type ChatMessage = {
Expand Down Expand Up @@ -184,23 +176,23 @@ export class Completions extends SuperOpenAi {
switch (request.tool) {
case "readFile":
if (!request.args.filename) throw new Error("Filename is required for readFile");
return this._readFile(request.args.filename, workingDir);
return this._readFile(request.args.filename as string, workingDir);

case "writeFile":
if (!request.args.filename || !request.args.content) {
throw new Error("Filename and content are required for writeFile");
}
return this._writeFile(request.args.filename, request.args.content, workingDir);
return this._writeFile(request.args.filename as string, request.args.content as string, workingDir);

case "exploreDir":
return this._getDirectoryTree(workingDir);

case "searchFiles":
if (!request.args.pattern) throw new Error("Search pattern is required");
return this._searchFiles(request.args.pattern, workingDir, {
filePattern: request.args.filePattern,
caseSensitive: request.args.caseSensitive,
contextLines: request.args.contextLines,
return this._searchFiles(request.args.pattern as string, workingDir, {
filePattern: request.args.filePattern as string,
caseSensitive: request.args.caseSensitive as boolean,
contextLines: request.args.contextLines as number,
});

default:
Expand All @@ -226,9 +218,22 @@ export class Completions extends SuperOpenAi {
const toolJson = toolBlock[1];

try {
console.log(toolJson);
this.context.logger.info(`Processing tool request:`, { toolJson });
const toolRequest: ToolRequest = JSON.parse(toolJson);
// Trim any whitespace and ensure we have valid JSON
const trimmedJson = toolJson.trim();
if (!trimmedJson.endsWith("}")) {
throw new Error("Malformed JSON: Missing closing brace");
}

this.context.logger.info(`Processing tool request:`, { toolJson: trimmedJson });
const toolRequest: ToolRequest = JSON.parse(trimmedJson);

// Validate required fields
if (!toolRequest.tool) {
throw new Error('Tool request missing required "tool" field');
}
if (!toolRequest.args) {
throw new Error('Tool request missing required "args" field');
}

// For writeFile, ensure content is stringified if it's an object
if (toolRequest.tool === "writeFile" && toolRequest.args.content && typeof toolRequest.args.content === "object") {
Expand Down Expand Up @@ -289,7 +294,7 @@ export class Completions extends SuperOpenAi {
tool: Tool<ToolResultMap[T]>,
method: string,
workingDir: string,
...args: unknown[]
args: Record<string, unknown>
): Promise<ToolResult<ToolResultMap[T]>> {
this.toolAttempts++;

Expand All @@ -309,12 +314,12 @@ export class Completions extends SuperOpenAi {
}

try {
const result = await tool.execute(...args);
const result = await tool.execute(args);

if (!result.success && this.toolAttempts < MAX_TRIES) {
const error = new Error(result.error || "Unknown error");
this.context.logger.error(`Tool attempt ${this.toolAttempts} failed:`, { error });
return this._executeWithRetry(tool, method, workingDir, ...args);
return this._executeWithRetry(tool, method, workingDir, args);
}

if (result.success) {
Expand All @@ -331,7 +336,7 @@ export class Completions extends SuperOpenAi {
this.context.logger.error(`Tool attempt ${this.toolAttempts} error:`, { error: errorObj });

if (this.toolAttempts < MAX_TRIES) {
return this._executeWithRetry(tool, method, workingDir, ...args);
return this._executeWithRetry(tool, method, workingDir, args);
}

return {
Expand Down Expand Up @@ -446,20 +451,20 @@ ${currentSolution}`;
}

private async _createPullRequest(title: string, body: string) {
return this._executeWithRetry(this.tools.createPr, "execute", "", title, body);
return this._executeWithRetry(this.tools.createPr, "execute", "", { title, body });
}

// Helper methods to execute tools with retry logic
private async _readFile(filename: string, workingDir: string) {
return this._executeWithRetry(this.tools.readFile, "execute", workingDir, filename);
return this._executeWithRetry(this.tools.readFile, "execute", workingDir, { filename });
}

private async _writeFile(filename: string, content: string, workingDir: string) {
return this._executeWithRetry(this.tools.writeFile, "execute", workingDir, filename, content);
return this._executeWithRetry(this.tools.writeFile, "execute", workingDir, { filename, content });
}

private async _getDirectoryTree(workingDir: string) {
return this._executeWithRetry(this.tools.exploreDir, "execute", workingDir, "tree");
return this._executeWithRetry(this.tools.exploreDir, "execute", workingDir, { command: "tree" });
}

private async _searchFiles(
Expand All @@ -471,6 +476,9 @@ ${currentSolution}`;
contextLines?: number;
}
) {
return this._executeWithRetry(this.tools.searchFiles, "execute", workingDir, pattern, options);
return this._executeWithRetry(this.tools.searchFiles, "execute", workingDir, {
pattern,
...options,
});
}
}
11 changes: 8 additions & 3 deletions src/handlers/front-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ export async function delegate(context: Context) {

try {
// First clone the repository
const cloneResult = await explore.execute("clone", { repo, owner, issueNumber });
const cloneResult = await explore.execute({
command: "clone",
repo,
owner,
issueNumber,
});
if (!cloneResult.success || !cloneResult.data) {
logger.error(`Failed to clone repository: ${cloneResult.error}`);
return;
Expand All @@ -25,7 +30,7 @@ export async function delegate(context: Context) {
const workingDir = cloneResult.data.currentPath;

// Get the directory tree for context
const treeResult = await explore.execute("tree");
const treeResult = await explore.execute({ command: "tree" });
const fileTree = treeResult.success && treeResult.data?.tree ? treeResult.data.tree : "";

// Start the completion process with the issue description and file tree
Expand Down Expand Up @@ -59,7 +64,7 @@ export async function delegate(context: Context) {
});

// Cleanup
await explore.execute("kill");
await explore.execute({ command: "kill" });
} catch (error) {
logger.error(`Error during completion: ${error instanceof Error ? error.message : "Unknown error"}`);

Expand Down
30 changes: 26 additions & 4 deletions src/tools/create-pr/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Tool, ToolResult } from "../../types/tool";
import { Tool, ToolResult, FunctionParameters } from "../../types/tool";
import { Context } from "../../types/context";
import { execSync } from "child_process";

Expand All @@ -8,17 +8,39 @@ export interface PullRequestResult {
title: string;
}

export class CreatePr implements Tool {
readonly name = "create-pr";
export class CreatePr implements Tool<PullRequestResult> {
readonly name = "createPr";
readonly description = "Creates a pull request with the changes";
readonly parameters: FunctionParameters = {
type: "object",
properties: {
title: {
type: "string",
description: "Title of the pull request",
},
body: {
type: "string",
description: "Description/body of the pull request",
},
},
required: ["title", "body"],
};

private _context: Context;

constructor(context: Context) {
this._context = context;
}

async execute(title: string, body: string): Promise<ToolResult<PullRequestResult>> {
async execute(args: Record<string, unknown>): Promise<ToolResult<PullRequestResult>> {
try {
const title = args.title as string;
const body = args.body as string;

if (!title || !body) {
throw new Error("Title and body are required");
}

try {
// Stage all changes
this._context.logger.info("Staging changes");
Expand Down
39 changes: 34 additions & 5 deletions src/tools/explore-dir/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import { Terminal } from "../terminal";
import { Tool, ToolResult, DirectoryExploreResult } from "../../types/tool";
import { Tool, ToolResult, DirectoryExploreResult, FunctionParameters } from "../../types/tool";

export class ExploreDir implements Tool<DirectoryExploreResult> {
readonly name = "explore-dir";
readonly name = "exploreDir";
readonly description = "Explores and manipulates directories, including git operations";
readonly parameters: FunctionParameters = {
type: "object",
properties: {
command: {
type: "string",
description: "Command to execute",
enum: ["change-dir", "tree", "clone", "kill"],
},
dir: {
type: "string",
description: "Directory path for change-dir command",
},
repo: {
type: "string",
description: "Repository name for clone command",
},
owner: {
type: "string",
description: "Repository owner for clone command",
},
issueNumber: {
type: "number",
description: "Issue number for clone command",
},
},
required: ["command"],
};

private _shellInterface: Terminal;
private _currentDir: string;
Expand All @@ -14,11 +41,13 @@ export class ExploreDir implements Tool<DirectoryExploreResult> {
this._currentDir = workDir;
}

async execute(command: "change-dir" | "tree" | "clone" | "kill", args?: Record<string, unknown>): Promise<ToolResult<DirectoryExploreResult>> {
async execute(args: Record<string, unknown>): Promise<ToolResult<DirectoryExploreResult>> {
const command = args.command as "change-dir" | "tree" | "clone" | "kill";

try {
switch (command) {
case "change-dir": {
const dir = args?.dir;
const dir = args.dir;
if (typeof dir !== "string") {
throw new Error("Directory path must be a string");
}
Expand All @@ -40,7 +69,7 @@ export class ExploreDir implements Tool<DirectoryExploreResult> {
};
}
case "clone": {
const { repo, owner, issueNumber } = args || {};
const { repo, owner, issueNumber } = args;
if (!repo || !owner || !issueNumber || typeof repo !== "string" || typeof owner !== "string" || typeof issueNumber !== "number") {
throw new Error("Missing required clone arguments");
}
Expand Down
25 changes: 20 additions & 5 deletions src/tools/read-file/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
import { readFileSync } from "fs";
import { Tool, ToolResult, FileReadResult } from "../../types/tool";
import { Tool, ToolResult, FileReadResult, FunctionParameters } from "../../types/tool";

export class ReadFile implements Tool {
readonly name = "read-file";
export class ReadFile implements Tool<FileReadResult> {
readonly name = "readFile";
readonly description = "Reads content from a file at the specified path";
readonly parameters: FunctionParameters = {
type: "object",
properties: {
filename: {
type: "string",
description: "Absolute path to the file",
},
},
required: ["filename"],
};

async execute(path: string): Promise<ToolResult<FileReadResult>> {
async execute(args: Record<string, unknown>): Promise<ToolResult<FileReadResult>> {
const path = args.filename as string;
try {
if (!path) {
throw new Error("Filename is required");
}

console.log(`Reading file: ${path}`);
const content = readFileSync(path, "utf8");

Expand Down Expand Up @@ -35,7 +50,7 @@ export class ReadFile implements Tool {

async batchRead(paths: string[]): Promise<ToolResult<FileReadResult[]>> {
try {
const results = await Promise.all(paths.map((path) => this.execute(path)));
const results = await Promise.all(paths.map((path) => this.execute({ filename: path })));
const successfulReads = results.filter((result) => result.success && result.data).map((result) => result.data as FileReadResult);

return {
Expand Down
Loading

0 comments on commit 239e201

Please sign in to comment.