Motivation - Why RAG?
In the world of AI-assisted programming, Retrieval-Augmented Generation (RAG) is reshaping how developers interact with large language models (LLMs) for code generation tasks. RAG combines a retrieval step with a generation model to produce outputs that are more accurate and better tailored to a specific task or project. Instead of overwhelming the LLM with an entire codebase, RAG breaks the code into manageable segments and constructs targeted prompts containing only the relevant context, improving the quality of the generated code. For instance, when implementing functionality similar to an existing project API, RAG identifies the relevant code snippets and incorporates them into the prompt, allowing the LLM to generate code that aligns closely with the project's requirements. RAG also offers scalability and adaptability, making it a versatile tool for developers tackling diverse projects. By leveraging existing knowledge and context, it streamlines code generation and paves the way for more efficient programming workflows. And yes, if we could feed arbitrarily long text as a prompt, we might not need RAG at all, since we could simply hand the whole codebase to the LLM; that may become practical in the future, but for now RAG remains very useful.

In this blog, we will build a code assistant (like Copilot) using an open-source LLM, zephyr-7b-beta, and LangChain for the RAG pipeline. We will walk through each module of the RAG setup in detail and show that a working Copilot for your own code is achievable with just a 16 GB GPU.
Dependencies
So let's say we have a codebase that we want to build RAG on top of. For parsing your code, LangChain provides language-aware parsers. We'll use Python as an example, but you can select your preferred language. At the time this blog was written, there was an issue with the langchain-community library's LanguageParser module, where some languages had not yet been added to the Language enum. To check that your installation is not affected by the LanguageParser issue, you can run the code below:
from langchain.text_splitter import Language
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import LanguageParser
This code snippet imports the necessary modules from langchain and langchain-community. If it runs without errors, the LanguageParser issue is resolved in your installation.
Parse Your Repo
Assuming you have a repository folder structured like this:
Firstly, we should load all files properly, meaning that we preserve the code structure as it is, for example keeping entire class or function definitions as one document and respecting the different levels of the code. For this purpose, we can use LanguageParser. We use Python here; choose your repository's code language instead if it differs.
# Load all Python files from the repository, keeping whole classes/functions together
# (repo_path is the root folder of your repository)
loader = GenericLoader.from_filesystem(
    repo_path,
    glob="**/*",
    suffixes=[".py"],
    exclude=["**/non-utf8-encoding.py"],
    parser=LanguageParser(language=Language.PYTHON, parser_threshold=500),
)
documents = loader.load()
len(documents)
Each document is structured as described above, but we still need to split the documents into smaller chunks so that we can later retrieve individual chunks and include them as context in the prompt for the LLM. For splitting, we can do the following:
from langchain.text_splitter import RecursiveCharacterTextSplitter

python_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
texts = python_splitter.split_documents(documents)
len(texts)
Here, the texts object is structured so that we can tell which document each chunk belongs to. Later, we will use this object to create embeddings.
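For example, you can peek at one chunk to see which source file it came from (a quick sanity check; the exact metadata fields may vary by parser version):

# Inspect the first chunk and its metadata (e.g. the source file path)
print(texts[0].page_content[:200])
print(texts[0].metadata)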
Embedding
To obtain embeddings for each text within the texts object, we first load an embedding model and apply it to the texts. These embeddings act as a lookup table: they are what allows us to find code similar to our query so that the LLM can generate more informed responses based on that context. RAG incorporates this related context into our prompts, leveraging embedding similarity to enrich the model's understanding.
There are various methods to load an embedding model. For example, you can use the OpenAI API straightforwardly:
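# A minimal sketch using OpenAI embeddings (assumes the OPENAI_API_KEY environment variable is set)
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()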
Alternatively, loading a custom model from the sentence-transformers
library—like "sentence-transformers/all-mpnet-base-v2"—can be done as follows:
from langchain_community.embeddings import HuggingFaceEmbeddings

model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
Ensure you have the sentence_transformers library installed. Note that "sentence-transformers/all-mpnet-base-v2" is the default model for HuggingFaceEmbeddings, so you can also instantiate it simply as HuggingFaceEmbeddings().
To apply the embedding model to texts, use:
from langchain_community.vectorstores import Chroma
db = Chroma.from_documents(texts, HuggingFaceEmbeddings())
Here, db is a Chroma vector database object. To perform actual searches within the Chroma database, define the search type and parameters using:
retriever = db.as_retriever(
    search_type="mmr",  # Can also try "similarity"
    search_kwargs={"k": 4},
)
In this case, we use mmr (Maximal Marginal Relevance) search, which retrieves documents similar to the query while also ensuring diversity among the results.
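As a quick check, you can query the retriever directly; the query below is just an illustrative example:

docs = retriever.get_relevant_documents("How is the BaseAPI class implemented?")
for doc in docs:
    print(doc.metadata.get("source"))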
Chat with LLM
You can easily use OpenAI to create LLMs for interaction by calling:
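# A minimal sketch using ChatOpenAI (assumes OPENAI_API_KEY is set; the model name is just an example)
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)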
However, if you're seeking a custom public LLM, like "HuggingFaceH4/zephyr-7b-beta," and want to load it locally on your GPU, you'll need to consider optimizations like load_in_8bit and other adjustments managed by libraries such as accelerate and bitsandbytes. Make sure to have these libraries installed.
To achieve this, you can use the HuggingFacePipeline
class, which takes a Hugging Face pipeline and adapts it into the necessary LangChain LLM for chaining purposes:
from transformers import pipeline
from langchain.llms import HuggingFacePipeline

pipe = pipeline("text-generation",
                model="HuggingFaceH4/zephyr-7b-beta",
                device_map="auto",
                model_kwargs={"load_in_8bit": True},
                max_new_tokens=512,
                top_k=30,
                temperature=0.1,
                repetition_penalty=1.03)
llm = HuggingFacePipeline(pipeline=pipe)
In this code snippet, a text-generation
pipeline is created using the transformers
library. The HuggingFacePipeline
object is then instantiated with this pipeline, providing an LLM object that can be used for interactive text generation within LangChain.
Another method to achieve this is by utilizing the HuggingFaceHub
class, which calls a model repository endpoint from Hugging Face:
from langchain_community.llms import HuggingFaceHub

llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
)
In this approach, we create an LLM that accesses the specified repository (HuggingFaceH4/zephyr-7b-beta
) for a text generation task. Unlike the method involving HuggingFacePipeline
, this does not require configurations like accelerate
and bitsandbytes
. However, it does rely on setting the HUGGINGFACEHUB_API_TOKEN
environment variable, which authorizes access to the Hugging Face model hub.
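For example, you can set the token from Python before building the LLM (the value below is a placeholder for your own access token):

import os

os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_your_token_here"  # placeholder token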
It's important to note the distinction between the HuggingFaceHub and HuggingFacePipeline classes in LangChain: HuggingFaceHub sends requests to a hosted endpoint for the specified model repository on Hugging Face, so inference runs remotely, whereas HuggingFacePipeline downloads the model and runs it locally on your own hardware.
Add Context to Prompt: Retrieve Relevant Doc
Let's imagine we want to find a relevant document for a given input. Typically, we instruct our LLM with an input like "Implement new API class called DataProvider.", expecting it to adhere to our repository structure, extend our BaseAPI, and so on. RAG (Retrieval-Augmented Generation) can assist in achieving this, with the first step being the identification of a relevant document for the input. Additionally, we need to take the chat_history into account, both to keep the conversation with the LLM coherent and to identify relevant documents. How can we achieve this?
To structure our prompt for finding relevant documents based on input
and chat_history
, we can utilize the ChatPromptTemplate
class:
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [
        ("placeholder", "{chat_history}"),
        ("user", "{input}"),
        (
            "user",
            "Given the above conversation, generate a search query to look up to get information relevant to the conversation",
        ),
    ]
)
For generating a search query to look up in the Chroma database, we have two scenarios: with or without chat history. We can use the RunnableBranch class to handle the two conditions accordingly:
from langchain_core.runnables import RunnableBranch
from langchain_core.output_parsers import StrOutputParser

generate_search_query = RunnableBranch(
    (
        # No chat history: use the raw input as the search query
        lambda x: not x.get("chat_history", False),
        (lambda x: x["input"]),
    ),
    # Otherwise: ask the LLM to condense the conversation into a search query
    prompt | llm | StrOutputParser(),
)
This generate_search_query runnable handles scenarios with and without chat history. When there is no chat history, it simply passes the input through as the search query. When chat history is available, it feeds the chat history together with the input into the prompt and lets the LLM generate the search query.
To demonstrate this, let's invoke the generate_search_query
:
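# A sketch of the call: with an empty chat history, the branch returns the input unchanged
generate_search_query.invoke({
    "input": "Implement new API class called DataProvider.",
    "chat_history": []
})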
The result will be:
'Implement new API class called DataProvider.'
Now, let's manually add chat history to the generate_search_query
chain:
from langchain_core.messages import AIMessage, HumanMessage

chat_history = [
    HumanMessage(content='Implement new API class called DataProvider.'),
    AIMessage(content='DataGathering:\n\n def __init__(self):\n self.data = {}')
]

result2 = generate_search_query.invoke({
    "input": "Add method get_data() to DataProvider class.",
    "chat_history": chat_history
})
The result will generate a prompt combining:
1. the chat history
2. the new input
3. the default prompt asking the LLM for a new search query based on the chat history
4. the LLM's output of the search query

Human: Implement new API class called DataProvider.
AI: DataGathering:

    def __init__(self):
        self.data = {}

Human: Add method get_data() to DataProvider class.
Human: Given the above conversation, generate a search query to look up to get information relevant to the conversation
AI: "How to implement a data provider class in Python with a dictionary for storing data and a get_data() method for retrieving it?"
However, you might find that we don't necessarily need the entire conversation transcript, just the final response from the AI. With the code snippet below, using the ChatOpenAI
class from langchain_openai
, you can achieve this efficiently:
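# A sketch using ChatOpenAI, which returns only the model's reply rather than the full transcript
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)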
By defining llm this way, you automatically receive only the last AI-generated message. Conversely, when utilizing the Hugging Face pipeline, as demonstrated with HuggingFacePipeline from langchain.llms, you can configure it to return only the newly generated text (rather than echoing the whole prompt) by setting return_full_text=False. This can be set up within your pipeline instantiation, as illustrated in the code snippet below:
from transformers import pipeline

pipe = pipeline("text-generation",
                model="HuggingFaceH4/zephyr-7b-beta",
                # additional configuration...
                return_full_text=False)
This setup ensures that the Hugging Face pipeline returns only the latest AI-generated response. To retrieve relevant documents based on the generated search query, we can create a retrieval chain:
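# A minimal sketch: pipe the generated search query into the retriever
retrieve_documents = generate_search_query | retriever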
Alternatively, LangChain provides a convenient function to encapsulate these steps:
from langchain.chains import create_history_aware_retriever
retrieve_documents = create_history_aware_retriever(llm, retriever, prompt)
This function simplifies the process by creating a history-aware retrieval chain that integrates the LLM, document retriever, and the specified prompt for generating search queries based on input and chat history.
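As an illustrative sanity check (using the chat_history defined earlier), you can invoke it directly:

docs = retrieve_documents.invoke({
    "input": "Add method get_data() to DataProvider class.",
    "chat_history": chat_history
})
len(docs)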
Chat with LLM Using Context
Now that we've established a document retrieval chain, we move on to creating a chain that answers user questions based on the retrieved documents as context. To achieve this, we'll use RunnablePassthrough, which lets us pass the input data through the chain while adding new keys. Our input data consists of the input, the chat history, and the context, where the context is the list of retrieved documents. We'll create a function that combines the retrieved documents into a single string, joined with double line breaks \n\n, defined as follows:
def get_page_content(inputs: dict) -> str:
    return '\n\n'.join(
        doc.page_content for doc in inputs['context']
    )
Additionally, we'll define a chat prompt template like so:
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer the user's questions based on the below context:\n\n{context}",
        ),
        ("placeholder", "{chat_history}"),
        ("user", "{input}"),
    ]
)
To integrate these components into our chain, we'll use RunnablePassthrough
to first combine all documents together and then add them to the prompt. This will ensure that the system message, chat history, and user input are combined effectively to provide as input to our language model (llm) for generating the answer. The chain setup will look like this:
from langchain_core.runnables import RunnablePassthrough

document_chain = (
    RunnablePassthrough.assign(context=get_page_content).with_config(
        run_name="format_inputs"
    )
    | prompt
    | llm
    | StrOutputParser()
).with_config(run_name="stuff_documents_chain")
Suppose we have a lengthy conversation, and we'd like to trim the chat history to just the last two messages. We can achieve this using RunnablePassthrough
to trim the chat history while leaving other inputs unchanged. The trim_messages
function is designed to handle this task by keeping the last two messages and clearing the rest:
def trim_messages(chain_input):
    stored_messages = chain_input['chat_history']
    if len(stored_messages) <= 2:
        return False
    chain_input['chat_history'] = []
    for message in stored_messages[-2:]:
        chain_input['chat_history'].append(message)
    return True
To ensure this trimming method is always called first in our chain, we'll integrate it as follows:
from langchain_core.runnables import RunnablePassthrough

document_chain = (
    RunnablePassthrough.assign(messages_trimmed=trim_messages)
    | RunnablePassthrough.assign(context=get_page_content).with_config(run_name="format_inputs")
    | prompt
    | llm
    | StrOutputParser()
).with_config(run_name="stuff_documents_chain")
In this setup, trim_messages
is invoked at the start of the chain using RunnablePassthrough
, which enables us to modify the chat history while preserving the integrity of other inputs like context
and input
. This ensures that only the necessary portion of the chat history is retained before proceeding with generating responses based on the provided context and user input.
To streamline the creation of the document chain, a shortcut has been implemented in langchain
as shown below:
from langchain.chains.combine_documents import create_stuff_documents_chain

# Creating the document chain without trimming the chat history
document_chain_without_trimming = create_stuff_documents_chain(llm, prompt)

# Integrating the chat history trimming step into the document chain
document_chain = (
    RunnablePassthrough.assign(messages_trimmed=trim_messages)
    | document_chain_without_trimming
)
In this updated snippet, the create_stuff_documents_chain
function from langchain
is used to quickly set up the document chain with the specified language model (llm
) and prompt. The original document chain is then enhanced by incorporating the trim_messages
step using RunnablePassthrough
, ensuring that the chat history is trimmed before the rest of the document processing takes place. This approach maintains clarity and efficiency in configuring the chain for generating responses based on the provided context and user input.
To invoke this chain, use the following method:
document_chain.invoke({
    "input": "Add method get_data() to DataProvider class.",
    "chat_history": chat_history.messages,
    "context": result2  # "context" expects the list of retrieved Document objects
})
However, we can optimize this process further. Instead of manually passing input
and chat history
, we can merge everything together internally within the chain. We'll cover this optimization in the next section.
Connect Chains to Make Chain
Now, let's consolidate the process into a single chain combining document retrieval and user input handling. Instead of managing separate chains for document retrieval (retrieve_documents
) and responding to user input (document_chain
), we can streamline this by passing the user input directly to one comprehensive chain.
retrieval_chain = (
    RunnablePassthrough.assign(
        context=retrieve_documents.with_config(run_name="retrieve_documents")
    )
    .assign(answer=document_chain)
).with_config(run_name="retrieval_chain")
Alternatively, you can streamline the process using the pre-implemented function create_retrieval_chain from langchain. This function combines the document retrieval chain (retrieve_documents) and the response generation chain (document_chain) into a single unified chain.
from langchain.chains import create_retrieval_chain
# Creating the retrieval chain
retrieval_chain = create_retrieval_chain(retrieve_documents, document_chain)
By utilizing create_retrieval_chain
, you can efficiently create a comprehensive chain that handles both document retrieval and user input processing seamlessly. This approach simplifies the workflow and ensures that the necessary components are integrated for generating responses based on the provided context and user input.
To invoke this chain seamlessly, consider the following invocation:
result = retrieval_chain.invoke({
    "input": "Add method set_data() to DataProvider class.",
    "chat_history": chat_history.messages
})
As observed in the result dictionary, the keys include 'input'
, 'chat_history'
, 'context'
, and 'answer'
. Despite these assignments during chain execution, we still need to manually pass 'chat_history'
. How can we automate this process? Enter RunnableWithMessageHistory
, a class designed to enable automatic history tracking. This class leverages the placeholder ("placeholder", "{chat_history}")
within the prompt template previously established.
from langchain.memory import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# If 'chat_history' is not yet defined
# chat_history = ChatMessageHistory()

retrieval_chain_with_message_history = RunnableWithMessageHistory(
    retrieval_chain,
    lambda session_id: chat_history,
    input_messages_key="input",
    output_messages_key="answer",
    history_messages_key="chat_history",
)
By incorporating RunnableWithMessageHistory
, the chain gains the ability to automatically manage and save chat history. Here's how you can invoke it with a specified session ID:
result = retrieval_chain_with_message_history.invoke(
    {"input": "Implement a new API class called DataGathering?"},
    {"configurable": {"session_id": "unused"}}
)
Upon checking chat_history.messages, you'll notice that the messages have been added automatically. This behavior comes from specifying input_messages_key="input" and output_messages_key="answer" on the Runnable, which indicates which parts of the input and output should be preserved as human and AI messages in the chat history.
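As a quick check, you can print the stored messages; the exact contents will depend on your model's responses:

for message in chat_history.messages:
    print(f"{type(message).__name__}: {message.content[:80]}")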