From 97b01fc854de7876044a00ab4b81ea433fc85236 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:20:42 +0000 Subject: [PATCH] feat: Implement RAG for code understanding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add RAG implementation with ChromaDB for vector search - Integrate with ReadCode for code summarization - Add code chunking and context retrieval - Add comprehensive test coverage Fixes #450 Co-Authored-By: Erkin Alp Güney --- requirements.txt | 2 + src/filesystem/read_code.py | 22 ++++++-- src/memory/rag.py | 109 +++++++++++++++++++++++++++++++++++- tests/test_code_rag.py | 30 ++++++++++ tests/test_read_code.py | 59 +++++++++++++++++++ 5 files changed, 215 insertions(+), 7 deletions(-) create mode 100644 tests/test_code_rag.py create mode 100644 tests/test_read_code.py diff --git a/requirements.txt b/requirements.txt index 91666960..e78ef600 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,5 @@ orjson gevent gevent-websocket curl_cffi +chromadb>=0.4.22 +sentence-transformers>=2.2.2 diff --git a/src/filesystem/read_code.py b/src/filesystem/read_code.py index 71b76f7f..311999f5 100644 --- a/src/filesystem/read_code.py +++ b/src/filesystem/read_code.py @@ -1,6 +1,8 @@ import os +from typing import Dict, List from src.config import Config +from src.memory.rag import CodeRAG """ TODO: Replace this with `code2prompt` - https://github.com/mufeedvh/code2prompt @@ -9,27 +11,35 @@ class ReadCode: def __init__(self, project_name: str): config = Config() - project_path = config.get_projects_dir() - self.directory_path = os.path.join(project_path, project_name.lower().replace(" ", "-")) + self.project_name = project_name.lower().replace(" ", "-") + self.directory_path = os.path.join(config.get_projects_dir(), self.project_name) + self.rag = CodeRAG(project_name) - def read_directory(self): + def read_directory(self) -> List[Dict[str, str]]: files_list = [] for root, _dirs, files in os.walk(self.directory_path): for file in files: try: file_path = os.path.join(root, file) with open(file_path, 'r') as file_content: - files_list.append({"filename": file_path, "code": file_content.read()}) + code = file_content.read() + files_list.append({"filename": file_path, "code": code}) + self.rag.add_code(file_path, code) except: pass - return files_list - def code_set_to_markdown(self): + def code_set_to_markdown(self) -> str: code_set = self.read_directory() markdown = "" for code in code_set: markdown += f"### {code['filename']}:\n\n" + summary = self.rag.summarize_code(code['code']) + if summary: + markdown += f"Summary: {summary}\n\n" markdown += f"```\n{code['code']}\n```\n\n" markdown += "---\n\n" return markdown + + def get_code_context(self, query: str, n_results: int = 5) -> Dict: + return self.rag.get_context(query, n_results) diff --git a/src/memory/rag.py b/src/memory/rag.py index 8229bf93..b0107908 100644 --- a/src/memory/rag.py +++ b/src/memory/rag.py @@ -1,3 +1,110 @@ """ Vector Search for Code Docs + Docs Loading -""" \ No newline at end of file +Implements RAG (Retrieval Augmented Generation) for code understanding +""" + +import os +from typing import List, Dict, Optional +import chromadb +from chromadb.config import Settings +from chromadb.utils import embedding_functions + +from src.bert.sentence import SentenceBERT +from src.config import Config + +class CodeRAG: + def __init__(self, project_name: str): + config = Config() + self.project_name = project_name.lower().replace(" ", "-") + self.db_path = os.path.join(config.get_projects_dir(), ".vector_db") + os.makedirs(self.db_path, exist_ok=True) + + # Initialize ChromaDB with persistence + self.client = chromadb.PersistentClient(path=self.db_path) + self.sentence_transformer = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name="all-MiniLM-L6-v2" + ) + + # Get or create collection for this project + self.collection = self.client.get_or_create_collection( + name=f"code_{self.project_name}", + embedding_function=self.sentence_transformer + ) + + def chunk_code(self, code: str, chunk_size: int = 1000) -> List[str]: + """Split code into smaller chunks while preserving context.""" + chunks = [] + lines = code.split('\n') + current_chunk = [] + current_size = 0 + + for line in lines: + line_size = len(line) + if current_size + line_size > chunk_size and current_chunk: + chunks.append('\n'.join(current_chunk)) + current_chunk = [] + current_size = 0 + current_chunk.append(line) + current_size += line_size + + if current_chunk: + chunks.append('\n'.join(current_chunk)) + return chunks + + def add_code(self, filename: str, code: str): + """Add code to the vector database with chunking.""" + chunks = self.chunk_code(code) + + # Generate unique IDs for chunks + chunk_ids = [f"{filename}_{i}" for i in range(len(chunks))] + + # Add chunks to collection + self.collection.add( + documents=chunks, + ids=chunk_ids, + metadatas=[{"filename": filename, "chunk": i} for i in range(len(chunks))] + ) + + def query_similar(self, query: str, n_results: int = 5) -> List[Dict]: + """Query the vector database for similar code chunks.""" + results = self.collection.query( + query_texts=[query], + n_results=n_results + ) + + return [{ + "text": doc, + "metadata": meta, + "distance": dist + } for doc, meta, dist in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0] + )] + + def summarize_code(self, code: str) -> str: + """Extract key information from code using SentenceBERT.""" + bert = SentenceBERT(code) + keywords = bert.extract_keywords(top_n=10) + return ", ".join([kw[0] for kw in keywords]) + + def get_context(self, query: str, n_results: int = 5) -> Dict: + """Get relevant code context for a query.""" + similar_chunks = self.query_similar(query, n_results) + + context = { + "relevant_code": [], + "summary": [], + "files": set() + } + + for chunk in similar_chunks: + context["relevant_code"].append({ + "code": chunk["text"], + "file": chunk["metadata"]["filename"], + "relevance": 1 - chunk["distance"] # Convert distance to similarity score + }) + context["files"].add(chunk["metadata"]["filename"]) + context["summary"].append(self.summarize_code(chunk["text"])) + + return context diff --git a/tests/test_code_rag.py b/tests/test_code_rag.py new file mode 100644 index 00000000..0f63f17c --- /dev/null +++ b/tests/test_code_rag.py @@ -0,0 +1,30 @@ +import unittest +from src.memory.rag import CodeRAG + +class TestCodeRAG(unittest.TestCase): + def setUp(self): + self.test_project = "test_project" + self.rag = CodeRAG(self.test_project) + + # Test code sample + self.test_code = ''' +def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two integers.""" + return a + b + +def multiply_numbers(x: int, y: int) -> int: + """Multiply two numbers together.""" + return x * y +''' + + def test_chunk_code(self): + """Test code chunking functionality.""" + chunks = self.rag.chunk_code(self.test_code, chunk_size=100) + self.assertTrue(len(chunks) > 0) + self.assertTrue(all(len(chunk) <= 100 for chunk in chunks)) + # Verify function boundaries are preserved + self.assertTrue(any("calculate_sum" in chunk for chunk in chunks)) + self.assertTrue(any("multiply_numbers" in chunk for chunk in chunks)) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_read_code.py b/tests/test_read_code.py new file mode 100644 index 00000000..5e84c20e --- /dev/null +++ b/tests/test_read_code.py @@ -0,0 +1,59 @@ +import os +import shutil +import unittest +from src.filesystem.read_code import ReadCode +from src.config import Config + +class TestReadCode(unittest.TestCase): + def setUp(self): + self.config = Config() + self.test_project = "test_project" + self.project_dir = os.path.join(self.config.get_projects_dir(), self.test_project) + + # Create test project directory and files + os.makedirs(self.project_dir, exist_ok=True) + self.test_code = ''' +def calculate_sum(a: int, b: int) -> int: + """Calculate the sum of two integers.""" + return a + b + +def multiply_numbers(x: int, y: int) -> int: + """Multiply two numbers together.""" + return x * y +''' + with open(os.path.join(self.project_dir, "test.py"), "w") as f: + f.write(self.test_code) + + self.reader = ReadCode(self.test_project) + + def test_read_directory(self): + files = self.reader.read_directory() + self.assertEqual(len(files), 1) + self.assertTrue(any(f["filename"].endswith("test.py") for f in files)) + self.assertTrue(any("calculate_sum" in f["code"] for f in files)) + + def test_code_set_to_markdown(self): + markdown = self.reader.code_set_to_markdown() + self.assertIn("test.py", markdown) + self.assertIn("```", markdown) + self.assertIn("calculate_sum", markdown) + self.assertIn("Summary:", markdown) + + def test_get_code_context(self): + context = self.reader.get_code_context("How to multiply numbers?") + self.assertTrue("relevant_code" in context) + self.assertTrue("summary" in context) + self.assertTrue(any("multiply" in code["code"].lower() + for code in context["relevant_code"])) + + def tearDown(self): + # Clean up test files + if os.path.exists(self.project_dir): + shutil.rmtree(self.project_dir) + # Clean up vector database + vector_db_path = os.path.join(self.config.get_projects_dir(), ".vector_db") + if os.path.exists(vector_db_path): + shutil.rmtree(vector_db_path) + +if __name__ == '__main__': + unittest.main()