forked from stitionai/devika
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement RAG for code understanding
- Add RAG implementation with ChromaDB for vector search - Integrate with ReadCode for code summarization - Add code chunking and context retrieval - Add comprehensive test coverage Fixes stitionai#450 Co-Authored-By: Erkin Alp Güney <[email protected]>
- Loading branch information
1 parent
3b98ed3
commit 97b01fc
Showing
5 changed files
with
215 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,3 +31,5 @@ orjson | |
gevent | ||
gevent-websocket | ||
curl_cffi | ||
chromadb>=0.4.22 | ||
sentence-transformers>=2.2.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,110 @@ | ||
""" | ||
Vector Search for Code Docs + Docs Loading | ||
""" | ||
Implements RAG (Retrieval Augmented Generation) for code understanding | ||
""" | ||
|
||
import os | ||
from typing import List, Dict, Optional | ||
import chromadb | ||
from chromadb.config import Settings | ||
from chromadb.utils import embedding_functions | ||
|
||
from src.bert.sentence import SentenceBERT | ||
from src.config import Config | ||
|
||
class CodeRAG: | ||
def __init__(self, project_name: str): | ||
config = Config() | ||
self.project_name = project_name.lower().replace(" ", "-") | ||
self.db_path = os.path.join(config.get_projects_dir(), ".vector_db") | ||
os.makedirs(self.db_path, exist_ok=True) | ||
|
||
# Initialize ChromaDB with persistence | ||
self.client = chromadb.PersistentClient(path=self.db_path) | ||
self.sentence_transformer = embedding_functions.SentenceTransformerEmbeddingFunction( | ||
model_name="all-MiniLM-L6-v2" | ||
) | ||
|
||
# Get or create collection for this project | ||
self.collection = self.client.get_or_create_collection( | ||
name=f"code_{self.project_name}", | ||
embedding_function=self.sentence_transformer | ||
) | ||
|
||
def chunk_code(self, code: str, chunk_size: int = 1000) -> List[str]: | ||
"""Split code into smaller chunks while preserving context.""" | ||
chunks = [] | ||
lines = code.split('\n') | ||
current_chunk = [] | ||
current_size = 0 | ||
|
||
for line in lines: | ||
line_size = len(line) | ||
if current_size + line_size > chunk_size and current_chunk: | ||
chunks.append('\n'.join(current_chunk)) | ||
current_chunk = [] | ||
current_size = 0 | ||
current_chunk.append(line) | ||
current_size += line_size | ||
|
||
if current_chunk: | ||
chunks.append('\n'.join(current_chunk)) | ||
return chunks | ||
|
||
def add_code(self, filename: str, code: str): | ||
"""Add code to the vector database with chunking.""" | ||
chunks = self.chunk_code(code) | ||
|
||
# Generate unique IDs for chunks | ||
chunk_ids = [f"{filename}_{i}" for i in range(len(chunks))] | ||
|
||
# Add chunks to collection | ||
self.collection.add( | ||
documents=chunks, | ||
ids=chunk_ids, | ||
metadatas=[{"filename": filename, "chunk": i} for i in range(len(chunks))] | ||
) | ||
|
||
def query_similar(self, query: str, n_results: int = 5) -> List[Dict]: | ||
"""Query the vector database for similar code chunks.""" | ||
results = self.collection.query( | ||
query_texts=[query], | ||
n_results=n_results | ||
) | ||
|
||
return [{ | ||
"text": doc, | ||
"metadata": meta, | ||
"distance": dist | ||
} for doc, meta, dist in zip( | ||
results["documents"][0], | ||
results["metadatas"][0], | ||
results["distances"][0] | ||
)] | ||
|
||
def summarize_code(self, code: str) -> str: | ||
"""Extract key information from code using SentenceBERT.""" | ||
bert = SentenceBERT(code) | ||
keywords = bert.extract_keywords(top_n=10) | ||
return ", ".join([kw[0] for kw in keywords]) | ||
|
||
def get_context(self, query: str, n_results: int = 5) -> Dict: | ||
"""Get relevant code context for a query.""" | ||
similar_chunks = self.query_similar(query, n_results) | ||
|
||
context = { | ||
"relevant_code": [], | ||
"summary": [], | ||
"files": set() | ||
} | ||
|
||
for chunk in similar_chunks: | ||
context["relevant_code"].append({ | ||
"code": chunk["text"], | ||
"file": chunk["metadata"]["filename"], | ||
"relevance": 1 - chunk["distance"] # Convert distance to similarity score | ||
}) | ||
context["files"].add(chunk["metadata"]["filename"]) | ||
context["summary"].append(self.summarize_code(chunk["text"])) | ||
|
||
return context |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import unittest | ||
from src.memory.rag import CodeRAG | ||
|
||
class TestCodeRAG(unittest.TestCase): | ||
def setUp(self): | ||
self.test_project = "test_project" | ||
self.rag = CodeRAG(self.test_project) | ||
|
||
# Test code sample | ||
self.test_code = ''' | ||
def calculate_sum(a: int, b: int) -> int: | ||
"""Calculate the sum of two integers.""" | ||
return a + b | ||
def multiply_numbers(x: int, y: int) -> int: | ||
"""Multiply two numbers together.""" | ||
return x * y | ||
''' | ||
|
||
def test_chunk_code(self): | ||
"""Test code chunking functionality.""" | ||
chunks = self.rag.chunk_code(self.test_code, chunk_size=100) | ||
self.assertTrue(len(chunks) > 0) | ||
self.assertTrue(all(len(chunk) <= 100 for chunk in chunks)) | ||
# Verify function boundaries are preserved | ||
self.assertTrue(any("calculate_sum" in chunk for chunk in chunks)) | ||
self.assertTrue(any("multiply_numbers" in chunk for chunk in chunks)) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import shutil | ||
import unittest | ||
from src.filesystem.read_code import ReadCode | ||
from src.config import Config | ||
|
||
class TestReadCode(unittest.TestCase): | ||
def setUp(self): | ||
self.config = Config() | ||
self.test_project = "test_project" | ||
self.project_dir = os.path.join(self.config.get_projects_dir(), self.test_project) | ||
|
||
# Create test project directory and files | ||
os.makedirs(self.project_dir, exist_ok=True) | ||
self.test_code = ''' | ||
def calculate_sum(a: int, b: int) -> int: | ||
"""Calculate the sum of two integers.""" | ||
return a + b | ||
def multiply_numbers(x: int, y: int) -> int: | ||
"""Multiply two numbers together.""" | ||
return x * y | ||
''' | ||
with open(os.path.join(self.project_dir, "test.py"), "w") as f: | ||
f.write(self.test_code) | ||
|
||
self.reader = ReadCode(self.test_project) | ||
|
||
def test_read_directory(self): | ||
files = self.reader.read_directory() | ||
self.assertEqual(len(files), 1) | ||
self.assertTrue(any(f["filename"].endswith("test.py") for f in files)) | ||
self.assertTrue(any("calculate_sum" in f["code"] for f in files)) | ||
|
||
def test_code_set_to_markdown(self): | ||
markdown = self.reader.code_set_to_markdown() | ||
self.assertIn("test.py", markdown) | ||
self.assertIn("```", markdown) | ||
self.assertIn("calculate_sum", markdown) | ||
self.assertIn("Summary:", markdown) | ||
|
||
def test_get_code_context(self): | ||
context = self.reader.get_code_context("How to multiply numbers?") | ||
self.assertTrue("relevant_code" in context) | ||
self.assertTrue("summary" in context) | ||
self.assertTrue(any("multiply" in code["code"].lower() | ||
for code in context["relevant_code"])) | ||
|
||
def tearDown(self): | ||
# Clean up test files | ||
if os.path.exists(self.project_dir): | ||
shutil.rmtree(self.project_dir) | ||
# Clean up vector database | ||
vector_db_path = os.path.join(self.config.get_projects_dir(), ".vector_db") | ||
if os.path.exists(vector_db_path): | ||
shutil.rmtree(vector_db_path) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |