Skip to content

Commit

Permalink
feat: Implement RAG for code understanding
Browse files Browse the repository at this point in the history
- Add RAG implementation with ChromaDB for vector search
- Integrate with ReadCode for code summarization
- Add code chunking and context retrieval
- Add comprehensive test coverage

Fixes stitionai#450

Co-Authored-By: Erkin Alp Güney <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and erkinalp committed Dec 18, 2024
1 parent 3b98ed3 commit 97b01fc
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 7 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ orjson
gevent
gevent-websocket
curl_cffi
chromadb>=0.4.22
sentence-transformers>=2.2.2
22 changes: 16 additions & 6 deletions src/filesystem/read_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Dict, List

from src.config import Config
from src.memory.rag import CodeRAG

"""
TODO: Replace this with `code2prompt` - https://github.com/mufeedvh/code2prompt
Expand All @@ -9,27 +11,35 @@
class ReadCode:
def __init__(self, project_name: str):
config = Config()
project_path = config.get_projects_dir()
self.directory_path = os.path.join(project_path, project_name.lower().replace(" ", "-"))
self.project_name = project_name.lower().replace(" ", "-")
self.directory_path = os.path.join(config.get_projects_dir(), self.project_name)
self.rag = CodeRAG(project_name)

def read_directory(self):
def read_directory(self) -> List[Dict[str, str]]:
files_list = []
for root, _dirs, files in os.walk(self.directory_path):
for file in files:
try:
file_path = os.path.join(root, file)
with open(file_path, 'r') as file_content:
files_list.append({"filename": file_path, "code": file_content.read()})
code = file_content.read()
files_list.append({"filename": file_path, "code": code})
self.rag.add_code(file_path, code)
except:
pass

return files_list

def code_set_to_markdown(self):
def code_set_to_markdown(self) -> str:
code_set = self.read_directory()
markdown = ""
for code in code_set:
markdown += f"### {code['filename']}:\n\n"
summary = self.rag.summarize_code(code['code'])
if summary:
markdown += f"Summary: {summary}\n\n"
markdown += f"```\n{code['code']}\n```\n\n"
markdown += "---\n\n"
return markdown

def get_code_context(self, query: str, n_results: int = 5) -> Dict:
return self.rag.get_context(query, n_results)
109 changes: 108 additions & 1 deletion src/memory/rag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,110 @@
"""
Vector Search for Code Docs + Docs Loading
"""
Implements RAG (Retrieval Augmented Generation) for code understanding
"""

import os
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from src.bert.sentence import SentenceBERT
from src.config import Config

class CodeRAG:
def __init__(self, project_name: str):
config = Config()
self.project_name = project_name.lower().replace(" ", "-")
self.db_path = os.path.join(config.get_projects_dir(), ".vector_db")
os.makedirs(self.db_path, exist_ok=True)

# Initialize ChromaDB with persistence
self.client = chromadb.PersistentClient(path=self.db_path)
self.sentence_transformer = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="all-MiniLM-L6-v2"
)

# Get or create collection for this project
self.collection = self.client.get_or_create_collection(
name=f"code_{self.project_name}",
embedding_function=self.sentence_transformer
)

def chunk_code(self, code: str, chunk_size: int = 1000) -> List[str]:
"""Split code into smaller chunks while preserving context."""
chunks = []
lines = code.split('\n')
current_chunk = []
current_size = 0

for line in lines:
line_size = len(line)
if current_size + line_size > chunk_size and current_chunk:
chunks.append('\n'.join(current_chunk))
current_chunk = []
current_size = 0
current_chunk.append(line)
current_size += line_size

if current_chunk:
chunks.append('\n'.join(current_chunk))
return chunks

def add_code(self, filename: str, code: str):
"""Add code to the vector database with chunking."""
chunks = self.chunk_code(code)

# Generate unique IDs for chunks
chunk_ids = [f"{filename}_{i}" for i in range(len(chunks))]

# Add chunks to collection
self.collection.add(
documents=chunks,
ids=chunk_ids,
metadatas=[{"filename": filename, "chunk": i} for i in range(len(chunks))]
)

def query_similar(self, query: str, n_results: int = 5) -> List[Dict]:
"""Query the vector database for similar code chunks."""
results = self.collection.query(
query_texts=[query],
n_results=n_results
)

return [{
"text": doc,
"metadata": meta,
"distance": dist
} for doc, meta, dist in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0]
)]

def summarize_code(self, code: str) -> str:
"""Extract key information from code using SentenceBERT."""
bert = SentenceBERT(code)
keywords = bert.extract_keywords(top_n=10)
return ", ".join([kw[0] for kw in keywords])

def get_context(self, query: str, n_results: int = 5) -> Dict:
"""Get relevant code context for a query."""
similar_chunks = self.query_similar(query, n_results)

context = {
"relevant_code": [],
"summary": [],
"files": set()
}

for chunk in similar_chunks:
context["relevant_code"].append({
"code": chunk["text"],
"file": chunk["metadata"]["filename"],
"relevance": 1 - chunk["distance"] # Convert distance to similarity score
})
context["files"].add(chunk["metadata"]["filename"])
context["summary"].append(self.summarize_code(chunk["text"]))

return context
30 changes: 30 additions & 0 deletions tests/test_code_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest
from src.memory.rag import CodeRAG

class TestCodeRAG(unittest.TestCase):
def setUp(self):
self.test_project = "test_project"
self.rag = CodeRAG(self.test_project)

# Test code sample
self.test_code = '''
def calculate_sum(a: int, b: int) -> int:
"""Calculate the sum of two integers."""
return a + b
def multiply_numbers(x: int, y: int) -> int:
"""Multiply two numbers together."""
return x * y
'''

def test_chunk_code(self):
"""Test code chunking functionality."""
chunks = self.rag.chunk_code(self.test_code, chunk_size=100)
self.assertTrue(len(chunks) > 0)
self.assertTrue(all(len(chunk) <= 100 for chunk in chunks))
# Verify function boundaries are preserved
self.assertTrue(any("calculate_sum" in chunk for chunk in chunks))
self.assertTrue(any("multiply_numbers" in chunk for chunk in chunks))

if __name__ == '__main__':
unittest.main()
59 changes: 59 additions & 0 deletions tests/test_read_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import shutil
import unittest
from src.filesystem.read_code import ReadCode
from src.config import Config

class TestReadCode(unittest.TestCase):
def setUp(self):
self.config = Config()
self.test_project = "test_project"
self.project_dir = os.path.join(self.config.get_projects_dir(), self.test_project)

# Create test project directory and files
os.makedirs(self.project_dir, exist_ok=True)
self.test_code = '''
def calculate_sum(a: int, b: int) -> int:
"""Calculate the sum of two integers."""
return a + b
def multiply_numbers(x: int, y: int) -> int:
"""Multiply two numbers together."""
return x * y
'''
with open(os.path.join(self.project_dir, "test.py"), "w") as f:
f.write(self.test_code)

self.reader = ReadCode(self.test_project)

def test_read_directory(self):
files = self.reader.read_directory()
self.assertEqual(len(files), 1)
self.assertTrue(any(f["filename"].endswith("test.py") for f in files))
self.assertTrue(any("calculate_sum" in f["code"] for f in files))

def test_code_set_to_markdown(self):
markdown = self.reader.code_set_to_markdown()
self.assertIn("test.py", markdown)
self.assertIn("```", markdown)
self.assertIn("calculate_sum", markdown)
self.assertIn("Summary:", markdown)

def test_get_code_context(self):
context = self.reader.get_code_context("How to multiply numbers?")
self.assertTrue("relevant_code" in context)
self.assertTrue("summary" in context)
self.assertTrue(any("multiply" in code["code"].lower()
for code in context["relevant_code"]))

def tearDown(self):
# Clean up test files
if os.path.exists(self.project_dir):
shutil.rmtree(self.project_dir)
# Clean up vector database
vector_db_path = os.path.join(self.config.get_projects_dir(), ".vector_db")
if os.path.exists(vector_db_path):
shutil.rmtree(vector_db_path)

if __name__ == '__main__':
unittest.main()

0 comments on commit 97b01fc

Please sign in to comment.