From 0d3e3fc0afd8ddea7b55248a72b4fe167447229b Mon Sep 17 00:00:00 2001 From: Madhu Date: Mon, 23 Dec 2024 22:29:07 +0530 Subject: [PATCH 1/6] RAG with database routing - first initialization --- rag_tutorials/rag_database_routing/README.md | 14 + .../rag_database_routing.py | 252 ++++++++++++++++++ .../rag_database_routing/requirements.txt | 9 + 3 files changed, 275 insertions(+) create mode 100644 rag_tutorials/rag_database_routing/README.md create mode 100644 rag_tutorials/rag_database_routing/rag_database_routing.py create mode 100644 rag_tutorials/rag_database_routing/requirements.txt diff --git a/rag_tutorials/rag_database_routing/README.md b/rag_tutorials/rag_database_routing/README.md new file mode 100644 index 0000000..106c4b2 --- /dev/null +++ b/rag_tutorials/rag_database_routing/README.md @@ -0,0 +1,14 @@ +# RAG Database Router Demo + +This demo showcases RAG (Retrieval Augmented Generation) with database routing capabilities. The application allows users to: + +1. Upload documents to three different databases: + - Product Information + - Customer Support & FAQ + - Financial Information + +2. Query information using natural language, with automatic routing to the most relevant database. + +## Setup + +1. Create a virtual environment: diff --git a/rag_tutorials/rag_database_routing/rag_database_routing.py b/rag_tutorials/rag_database_routing/rag_database_routing.py new file mode 100644 index 0000000..e4ad9bf --- /dev/null +++ b/rag_tutorials/rag_database_routing/rag_database_routing.py @@ -0,0 +1,252 @@ +import os +from typing import List, Dict, Any, Literal +from dataclasses import dataclass +import streamlit as st +from dotenv import load_dotenv +from langchain_core.documents import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import PyPDFLoader +from langchain_community.vectorstores import Chroma +from langchain_community.embeddings import OpenAIEmbeddings +from langchain_openai import ChatOpenAI +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +import tempfile + +# Load environment variables +load_dotenv() + +# Constants +DatabaseType = Literal["products", "customer_support", "financials"] +PERSIST_DIRECTORY = "db_storage" + +@dataclass +class Database: + """Class to represent a database configuration""" + name: str + description: str + collection_name: str + persist_directory: str + +# Database configurations +DATABASES: Dict[DatabaseType, Database] = { + "products": Database( + name="Product Information", + description="Product details, specifications, and features", + collection_name="products_db", + persist_directory=f"{PERSIST_DIRECTORY}/products" + ), + "customer_support": Database( + name="Customer Support & FAQ", + description="Customer support information, frequently asked questions, and guides", + collection_name="support_db", + persist_directory=f"{PERSIST_DIRECTORY}/support" + ), + "financials": Database( + name="Financial Information", + description="Financial data, revenue, costs, and liabilities", + collection_name="finance_db", + persist_directory=f"{PERSIST_DIRECTORY}/finance" + ) +} + +# Router prompt template +ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database. + +Available databases: +1. Product Information: Contains product details, specifications, and features +2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides +3. Financial Information: Contains financial data, revenue, costs, and liabilities + +User question: {question} + +Return only one of these exact strings: +- products +- customer_support +- financials + +Your response:""" + +def init_session_state(): + """Initialize session state variables""" + if 'databases' not in st.session_state: + st.session_state.databases = {} + if 'embeddings' not in st.session_state: + st.session_state.embeddings = OpenAIEmbeddings() + if 'llm' not in st.session_state: + st.session_state.llm = ChatOpenAI(temperature=0) + if 'router_chain' not in st.session_state: + router_prompt = PromptTemplate( + template=ROUTER_TEMPLATE, + input_variables=["question"] + ) + st.session_state.router_chain = LLMChain( + llm=st.session_state.llm, + prompt=router_prompt + ) + +def process_document(file) -> List[Document]: + """Process uploaded PDF document""" + try: + with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file: + tmp_file.write(file.getvalue()) + tmp_path = tmp_file.name + + loader = PyPDFLoader(tmp_path) + documents = loader.load() + + # Clean up temporary file + os.unlink(tmp_path) + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200 + ) + texts = text_splitter.split_documents(documents) + + return texts + except Exception as e: + st.error(f"Error processing document: {e}") + return [] + +def get_or_create_db(db_type: DatabaseType) -> Chroma: + """Get or create a database for the specified type with proper initialization and error handling""" + try: + if db_type not in st.session_state.databases: + db_config = DATABASES[db_type] + + # Ensure directory exists + os.makedirs(db_config.persist_directory, exist_ok=True) + + # Initialize Chroma with proper settings + st.session_state.databases[db_type] = Chroma( + persist_directory=db_config.persist_directory, + embedding_function=st.session_state.embeddings, + collection_name=db_config.collection_name, + collection_metadata={ + "description": db_config.description, + "database_type": db_type + } + ) + + # Log successful initialization + st.success(f"Initialized {db_config.name} database") + + return st.session_state.databases[db_type] + + except Exception as e: + st.error(f"Error initializing {db_type} database: {str(e)}") + raise + +def route_query(question: str) -> DatabaseType: + """Route the question to the appropriate database""" + response = st.session_state.router_chain.invoke({"question": question}) + return response["text"].strip().lower() + +def query_database(db: Chroma, question: str) -> str: + """Query the database and return the response""" + docs = db.similarity_search(question, k=3) + + context = "\n\n".join([doc.page_content for doc in docs]) + + prompt = PromptTemplate( + template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question." + +Context: {context} + +Question: {question} + +Answer:""", + input_variables=["context", "question"] + ) + + chain = LLMChain(llm=st.session_state.llm, prompt=prompt) + response = chain.invoke({"context": context, "question": question}) + return response["text"] + +def clear_database(db_type: DatabaseType = None): + """Clear specified database or all databases if none specified""" + try: + if db_type: + if db_type in st.session_state.databases: + db_config = DATABASES[db_type] + # Delete collection + st.session_state.databases[db_type]._collection.delete() + # Remove from session state + del st.session_state.databases[db_type] + # Clean up persist directory + if os.path.exists(db_config.persist_directory): + import shutil + shutil.rmtree(db_config.persist_directory) + st.success(f"Cleared {db_config.name} database") + else: + # Clear all databases + for db_type, db_config in DATABASES.items(): + if db_type in st.session_state.databases: + st.session_state.databases[db_type]._collection.delete() + if os.path.exists(db_config.persist_directory): + import shutil + shutil.rmtree(db_config.persist_directory) + st.session_state.databases = {} + st.success("Cleared all databases") + except Exception as e: + st.error(f"Error clearing database(s): {str(e)}") + +def main(): + st.title("📚 RAG Database Router ") + + init_session_state() + + # Sidebar for database management + with st.sidebar: + st.header("Database Management") + if st.button("Clear All Databases"): + clear_database() + + st.divider() + st.subheader("Clear Individual Databases") + for db_type, db_config in DATABASES.items(): + if st.button(f"Clear {db_config.name}"): + clear_database(db_type) + + # Document upload section + st.header("Document Upload") + tabs = st.tabs([db.name for db in DATABASES.values()]) + + for (db_type, db_config), tab in zip(DATABASES.items(), tabs): + with tab: + st.write(db_config.description) + uploaded_file = st.file_uploader( + "Upload PDF document", + type="pdf", + key=f"upload_{db_type}" + ) + + if uploaded_file: + with st.spinner('Processing document...'): + texts = process_document(uploaded_file) + if texts: + db = get_or_create_db(db_type) + db.add_documents(texts) + st.success("Document processed and added to the database!") + + # Query section + st.header("Ask Questions") + question = st.text_input("Enter your question:") + + if question: + with st.spinner('Finding answer...'): + # Route the question + db_type = route_query(question) + db = get_or_create_db(db_type) + + # Display routing information + st.info(f"Routing question to: {DATABASES[db_type].name}") + + # Get and display answer + answer = query_database(db, question) + st.write("### Answer") + st.write(answer) + +if __name__ == "__main__": + main() diff --git a/rag_tutorials/rag_database_routing/requirements.txt b/rag_tutorials/rag_database_routing/requirements.txt new file mode 100644 index 0000000..0ce6d76 --- /dev/null +++ b/rag_tutorials/rag_database_routing/requirements.txt @@ -0,0 +1,9 @@ +langchain>=0.1.0 +langchain-community>=0.0.10 +langchain-core>=0.1.10 +chromadb>=0.4.22 +streamlit>=1.29.0 +python-dotenv>=1.0.0 +pypdf>=4.0.0 +sentence-transformers>=2.2.2 +openai>=1.6.1 From 29251a2c5db908edd895a2aa4ba340158977d85b Mon Sep 17 00:00:00 2001 From: Madhu Date: Tue, 24 Dec 2024 21:31:24 +0530 Subject: [PATCH 2/6] simple implementation of chain based - db routing --- .../rag_database_routing.py | 307 ++++++++++-------- 1 file changed, 171 insertions(+), 136 deletions(-) diff --git a/rag_tutorials/rag_database_routing/rag_database_routing.py b/rag_tutorials/rag_database_routing/rag_database_routing.py index e4ad9bf..fcd7723 100644 --- a/rag_tutorials/rag_database_routing/rag_database_routing.py +++ b/rag_tutorials/rag_database_routing/rag_database_routing.py @@ -1,4 +1,5 @@ import os +import getpass from typing import List, Dict, Any, Literal from dataclasses import dataclass import streamlit as st @@ -7,51 +8,35 @@ from langchain_core.documents import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores import Chroma -from langchain_community.embeddings import OpenAIEmbeddings -from langchain_openai import ChatOpenAI +from langchain_openai import OpenAIEmbeddings from langchain.chains import LLMChain -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI import tempfile +from langchain_core.runnables import RunnableSequence +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_chroma import Chroma -# Load environment variables -load_dotenv() +def init_session_state(): + """Initialize session state variables""" + if 'openai_api_key' not in st.session_state: + st.session_state.openai_api_key = "" + if 'embeddings' not in st.session_state: + st.session_state.embeddings = None + if 'llm' not in st.session_state: + st.session_state.llm = None + if 'databases' not in st.session_state: + st.session_state.databases = {} + +# Initialize session state at the top +init_session_state() # Constants DatabaseType = Literal["products", "customer_support", "financials"] PERSIST_DIRECTORY = "db_storage" -@dataclass -class Database: - """Class to represent a database configuration""" - name: str - description: str - collection_name: str - persist_directory: str - -# Database configurations -DATABASES: Dict[DatabaseType, Database] = { - "products": Database( - name="Product Information", - description="Product details, specifications, and features", - collection_name="products_db", - persist_directory=f"{PERSIST_DIRECTORY}/products" - ), - "customer_support": Database( - name="Customer Support & FAQ", - description="Customer support information, frequently asked questions, and guides", - collection_name="support_db", - persist_directory=f"{PERSIST_DIRECTORY}/support" - ), - "financials": Database( - name="Financial Information", - description="Financial data, revenue, costs, and liabilities", - collection_name="finance_db", - persist_directory=f"{PERSIST_DIRECTORY}/finance" - ) -} - -# Router prompt template -ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database. +ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and determine which databases might contain relevant information. Available databases: 1. Product Information: Contains product details, specifications, and features @@ -60,33 +45,85 @@ Available databases: User question: {question} -Return only one of these exact strings: +Return a comma-separated list of relevant databases (no spaces after commas). Only use these exact strings: - products - customer_support - financials +For example: "products,customer_support" if the question relates to both product info and support. Your response:""" -def init_session_state(): - """Initialize session state variables""" - if 'databases' not in st.session_state: - st.session_state.databases = {} - if 'embeddings' not in st.session_state: - st.session_state.embeddings = OpenAIEmbeddings() - if 'llm' not in st.session_state: - st.session_state.llm = ChatOpenAI(temperature=0) - if 'router_chain' not in st.session_state: - router_prompt = PromptTemplate( - template=ROUTER_TEMPLATE, - input_variables=["question"] - ) - st.session_state.router_chain = LLMChain( - llm=st.session_state.llm, - prompt=router_prompt - ) +@dataclass +class CollectionConfig: + name: str + description: str + collection_name: str + persist_directory: str + +# Collection configurations +COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { + "products": CollectionConfig( + name="Product Information", + description="Product details, specifications, and features", + collection_name="products_collection", + persist_directory=f"{PERSIST_DIRECTORY}/products" + ), + "customer_support": CollectionConfig( + name="Customer Support & FAQ", + description="Customer support information, frequently asked questions, and guides", + collection_name="support_collection", + persist_directory=f"{PERSIST_DIRECTORY}/support" + ), + "financials": CollectionConfig( + name="Financial Information", + description="Financial data, revenue, costs, and liabilities", + collection_name="finance_collection", + persist_directory=f"{PERSIST_DIRECTORY}/finance" + ) +} + +def initialize_models(): + """Initialize OpenAI models with API key""" + if st.session_state.openai_api_key: + try: + os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key + # Test the API key with a small embedding request + test_embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + test_embeddings.embed_query("test") + + # If successful, initialize the models + st.session_state.embeddings = test_embeddings + st.session_state.llm = ChatOpenAI(temperature=0) + st.session_state.databases = { + "products": Chroma( + collection_name=COLLECTIONS["products"].collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=COLLECTIONS["products"].persist_directory + ), + "customer_support": Chroma( + collection_name=COLLECTIONS["customer_support"].collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=COLLECTIONS["customer_support"].persist_directory + ), + "financials": Chroma( + collection_name=COLLECTIONS["financials"].collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=COLLECTIONS["financials"].persist_directory + ) + } + return True + except Exception as e: + st.error(f"Error connecting to OpenAI API: {str(e)}") + st.error("Please check your internet connection and API key.") + return False + return False def process_document(file) -> List[Document]: """Process uploaded PDF document""" + if not st.session_state.embeddings: + st.error("OpenAI API connection not initialized. Please check your API key.") + return [] + try: with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file: tmp_file.write(file.getvalue()) @@ -109,124 +146,123 @@ def process_document(file) -> List[Document]: st.error(f"Error processing document: {e}") return [] -def get_or_create_db(db_type: DatabaseType) -> Chroma: - """Get or create a database for the specified type with proper initialization and error handling""" - try: - if db_type not in st.session_state.databases: - db_config = DATABASES[db_type] - - # Ensure directory exists - os.makedirs(db_config.persist_directory, exist_ok=True) - - # Initialize Chroma with proper settings - st.session_state.databases[db_type] = Chroma( - persist_directory=db_config.persist_directory, - embedding_function=st.session_state.embeddings, - collection_name=db_config.collection_name, - collection_metadata={ - "description": db_config.description, - "database_type": db_type - } - ) - - # Log successful initialization - st.success(f"Initialized {db_config.name} database") - - return st.session_state.databases[db_type] - - except Exception as e: - st.error(f"Error initializing {db_type} database: {str(e)}") - raise +def route_query(question: str) -> List[DatabaseType]: + """Route the question to appropriate databases""" + router_prompt = ChatPromptTemplate.from_template(ROUTER_TEMPLATE) + router_chain = router_prompt | st.session_state.llm | StrOutputParser() + response = router_chain.invoke({"question": question}) + return response.strip().lower().split(",") -def route_query(question: str) -> DatabaseType: - """Route the question to the appropriate database""" - response = st.session_state.router_chain.invoke({"question": question}) - return response["text"].strip().lower() - -def query_database(db: Chroma, question: str) -> str: - """Query the database and return the response""" - docs = db.similarity_search(question, k=3) +def query_multiple_databases(question: str) -> str: + """Query multiple relevant databases and combine results""" + database_types = route_query(question) + all_docs = [] - context = "\n\n".join([doc.page_content for doc in docs]) + # Collect relevant documents from each database + for db_type in database_types: + db = st.session_state.databases[db_type] + docs = db.similarity_search(question, k=2) # Reduced k since we're querying multiple DBs + all_docs.extend(docs) - prompt = PromptTemplate( - template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question." + # Sort all documents by relevance score if available + # Note: You might need to modify this based on your similarity search implementation + context = "\n\n---\n\n".join([doc.page_content for doc in all_docs]) + + answer_prompt = ChatPromptTemplate.from_template( + """Answer the question based on the following context from multiple databases. + If you use information from multiple sources, please indicate which type of source it came from. + If you cannot answer the question based on the context, say "I don't have enough information to answer this question." Context: {context} Question: {question} -Answer:""", - input_variables=["context", "question"] +Answer:""" ) - chain = LLMChain(llm=st.session_state.llm, prompt=prompt) - response = chain.invoke({"context": context, "question": question}) - return response["text"] + answer_chain = answer_prompt | st.session_state.llm | StrOutputParser() + return answer_chain.invoke({"context": context, "question": question}) -def clear_database(db_type: DatabaseType = None): - """Clear specified database or all databases if none specified""" +def clear_collection(collection_type: DatabaseType = None): + """Clear specified collection or all collections if none specified""" try: - if db_type: - if db_type in st.session_state.databases: - db_config = DATABASES[db_type] + if collection_type: + if collection_type in st.session_state.databases: + collection_config = COLLECTIONS[collection_type] # Delete collection - st.session_state.databases[db_type]._collection.delete() + st.session_state.databases[collection_type]._collection.delete() # Remove from session state - del st.session_state.databases[db_type] + del st.session_state.databases[collection_type] # Clean up persist directory - if os.path.exists(db_config.persist_directory): + if os.path.exists(collection_config.persist_directory): import shutil - shutil.rmtree(db_config.persist_directory) - st.success(f"Cleared {db_config.name} database") + shutil.rmtree(collection_config.persist_directory) + st.success(f"Cleared {collection_config.name} collection") else: - # Clear all databases - for db_type, db_config in DATABASES.items(): - if db_type in st.session_state.databases: - st.session_state.databases[db_type]._collection.delete() - if os.path.exists(db_config.persist_directory): + # Clear all collections + for collection_type, collection_config in COLLECTIONS.items(): + if collection_type in st.session_state.databases: + st.session_state.databases[collection_type]._collection.delete() + if os.path.exists(collection_config.persist_directory): import shutil - shutil.rmtree(db_config.persist_directory) + shutil.rmtree(collection_config.persist_directory) st.session_state.databases = {} - st.success("Cleared all databases") + st.success("Cleared all collections") except Exception as e: - st.error(f"Error clearing database(s): {str(e)}") + st.error(f"Error clearing collection(s): {str(e)}") def main(): - st.title("📚 RAG Database Router ") - - init_session_state() - - # Sidebar for database management + st.title("📚 RAG with Database Routing") + with st.sidebar: + st.header("Configuration") + api_key = st.text_input( + "Enter OpenAI API Key:", + type="password", + value=st.session_state.openai_api_key, + key="api_key_input" + ) + + if api_key: + st.session_state.openai_api_key = api_key + if initialize_models(): + st.success("API Key set successfully!") + else: + st.error("Invalid API Key") + + if not st.session_state.openai_api_key: + st.warning("Please enter your OpenAI API key to continue") + st.stop() + + st.divider() st.header("Database Management") if st.button("Clear All Databases"): - clear_database() + clear_collection() st.divider() st.subheader("Clear Individual Databases") - for db_type, db_config in DATABASES.items(): - if st.button(f"Clear {db_config.name}"): - clear_database(db_type) + for collection_type, collection_config in COLLECTIONS.items(): + if st.button(f"Clear {collection_config.name}"): + clear_collection(collection_type) # Document upload section st.header("Document Upload") - tabs = st.tabs([db.name for db in DATABASES.values()]) + tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()]) - for (db_type, db_config), tab in zip(DATABASES.items(), tabs): + for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs): with tab: - st.write(db_config.description) + st.write(collection_config.description) uploaded_file = st.file_uploader( "Upload PDF document", type="pdf", - key=f"upload_{db_type}" + key=f"upload_{collection_type}" ) if uploaded_file: with st.spinner('Processing document...'): texts = process_document(uploaded_file) if texts: - db = get_or_create_db(db_type) + db = st.session_state.databases[collection_type] db.add_documents(texts) st.success("Document processed and added to the database!") @@ -236,15 +272,14 @@ def main(): if question: with st.spinner('Finding answer...'): - # Route the question - db_type = route_query(question) - db = get_or_create_db(db_type) + # Get relevant databases + database_types = route_query(question) # Display routing information - st.info(f"Routing question to: {DATABASES[db_type].name}") + st.info(f"Searching in: {', '.join([COLLECTIONS[db_type].name for db_type in database_types])}") # Get and display answer - answer = query_database(db, question) + answer = query_multiple_databases(question) st.write("### Answer") st.write(answer) From 7035e9e64186c130020f69ca1e1112d75c357149 Mon Sep 17 00:00:00 2001 From: Madhu Date: Wed, 25 Dec 2024 02:21:40 +0530 Subject: [PATCH 3/6] added everything - testing time --- rag_tutorials/rag_database_routing/README.md | 50 ++- .../rag_database_routing.py | 314 ++++++++---------- .../rag_database_routing/requirements.txt | 14 +- 3 files changed, 184 insertions(+), 194 deletions(-) diff --git a/rag_tutorials/rag_database_routing/README.md b/rag_tutorials/rag_database_routing/README.md index 106c4b2..279f8a6 100644 --- a/rag_tutorials/rag_database_routing/README.md +++ b/rag_tutorials/rag_database_routing/README.md @@ -1,6 +1,6 @@ -# RAG Database Router Demo +# RAG Agent with Database Routing -This demo showcases RAG (Retrieval Augmented Generation) with database routing capabilities. The application allows users to: +This project showcases the RAG with database routing capabilities - which is a very efficient way to retrieve information from a large set of documents. The application allows users to: 1. Upload documents to three different databases: - Product Information @@ -9,6 +9,48 @@ This demo showcases RAG (Retrieval Augmented Generation) with database routing c 2. Query information using natural language, with automatic routing to the most relevant database. -## Setup +## Features -1. Create a virtual environment: +- **Document Upload**: Users can upload multiple PDF documents related to a particular company. These documents are processed and stored in one of the three databases: Product Information, Customer Support & FAQ, or Financial Information. + +- **Natural Language Querying**: Users can ask questions in natural language. The system automatically routes the query to the most relevant database using a phidata agent as the router. + +- **RAG Orchestration**: Utilizes Langchain for orchestrating the retrieval augmented generation process, ensuring that the most relevant information is retrieved and presented to the user. + +- **Fallback Mechanism**: If no relevant documents are found in the databases, a LangGraph agent with a DuckDuckGo search tool is used to perform web research and provide an answer. + +- **User Interface**: Built with Streamlit, providing an intuitive and interactive user experience. + +## How to Run? + +1. **Clone the Repository**: + ```bash + git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git + cd rag_tutorials/rag_database_routing + ``` + +2. **Install Dependencies**: + ```bash + pip install -r requirements.txt + ``` + +3. **Run the Application**: + ```bash + streamlit run rag_database_routing.py + ``` + +4. **Configure API Key**: Obtain an OpenAI API key and set it in the application. This is required for initializing the language models used in the application. + +5. **Upload Documents**: Use the document upload section to add PDF documents to the desired database. + +6. **Ask Questions**: Enter your questions in the query section. The application will route your question to the appropriate database and provide an answer. + +## Technologies Used + +- **Langchain**: For RAG orchestration, ensuring efficient retrieval and generation of information. +- **Phidata Agent**: Used as the router agent to determine the most relevant database for a given query. +- **LangGraph Agent**: Acts as a fallback mechanism, utilizing DuckDuckGo for web research when necessary. +- **Streamlit**: Provides a user-friendly interface for document upload and querying. +- **ChromaDB**: Used for managing the databases, storing and retrieving document embeddings efficiently. + +This application is designed to streamline the process of retrieving information from large sets of documents, making it easier for users to find the answers they need quickly and efficiently. diff --git a/rag_tutorials/rag_database_routing/rag_database_routing.py b/rag_tutorials/rag_database_routing/rag_database_routing.py index fcd7723..6582225 100644 --- a/rag_tutorials/rag_database_routing/rag_database_routing.py +++ b/rag_tutorials/rag_database_routing/rag_database_routing.py @@ -1,25 +1,24 @@ import os -import getpass -from typing import List, Dict, Any, Literal +from typing import List, Dict, Literal from dataclasses import dataclass import streamlit as st -from dotenv import load_dotenv from langchain_core.documents import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores import Chroma -from langchain_openai import OpenAIEmbeddings -from langchain.chains import LLMChain -from langchain_core.prompts import PromptTemplate -from langchain_openai import ChatOpenAI +from langchain_openai import OpenAIEmbeddings, ChatOpenAI import tempfile -from langchain_core.runnables import RunnableSequence -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import ChatPromptTemplate -from langchain_chroma import Chroma +from phi.agent import Agent +from phi.model.openai import OpenAIChat +from langchain.schema import HumanMessage +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain.chains import create_retrieval_chain +from langchain import hub +from langgraph.prebuilt import create_react_agent +from langchain_community.tools import DuckDuckGoSearchRun +from langchain_core.language_models import BaseLanguageModel def init_session_state(): - """Initialize session state variables""" if 'openai_api_key' not in st.session_state: st.session_state.openai_api_key = "" if 'embeddings' not in st.session_state: @@ -29,30 +28,11 @@ def init_session_state(): if 'databases' not in st.session_state: st.session_state.databases = {} -# Initialize session state at the top init_session_state() -# Constants -DatabaseType = Literal["products", "customer_support", "financials"] +DatabaseType = Literal["products", "support", "finance"] PERSIST_DIRECTORY = "db_storage" -ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and determine which databases might contain relevant information. - -Available databases: -1. Product Information: Contains product details, specifications, and features -2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides -3. Financial Information: Contains financial data, revenue, costs, and liabilities - -User question: {question} - -Return a comma-separated list of relevant databases (no spaces after commas). Only use these exact strings: -- products -- customer_support -- financials - -For example: "products,customer_support" if the question relates to both product info and support. -Your response:""" - @dataclass class CollectionConfig: name: str @@ -60,7 +40,6 @@ class CollectionConfig: collection_name: str persist_directory: str -# Collection configurations COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { "products": CollectionConfig( name="Product Information", @@ -68,13 +47,13 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { collection_name="products_collection", persist_directory=f"{PERSIST_DIRECTORY}/products" ), - "customer_support": CollectionConfig( + "support": CollectionConfig( name="Customer Support & FAQ", description="Customer support information, frequently asked questions, and guides", collection_name="support_collection", persist_directory=f"{PERSIST_DIRECTORY}/support" ), - "financials": CollectionConfig( + "finance": CollectionConfig( name="Financial Information", description="Financial data, revenue, costs, and liabilities", collection_name="finance_collection", @@ -83,47 +62,25 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { } def initialize_models(): - """Initialize OpenAI models with API key""" if st.session_state.openai_api_key: - try: - os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key - # Test the API key with a small embedding request - test_embeddings = OpenAIEmbeddings(model="text-embedding-3-large") - test_embeddings.embed_query("test") - - # If successful, initialize the models - st.session_state.embeddings = test_embeddings - st.session_state.llm = ChatOpenAI(temperature=0) - st.session_state.databases = { - "products": Chroma( - collection_name=COLLECTIONS["products"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["products"].persist_directory - ), - "customer_support": Chroma( - collection_name=COLLECTIONS["customer_support"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["customer_support"].persist_directory - ), - "financials": Chroma( - collection_name=COLLECTIONS["financials"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["financials"].persist_directory - ) - } - return True - except Exception as e: - st.error(f"Error connecting to OpenAI API: {str(e)}") - st.error("Please check your internet connection and API key.") - return False + os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key + st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + st.session_state.llm = ChatOpenAI(temperature=0) + + for config in COLLECTIONS.values(): + os.makedirs(config.persist_directory, exist_ok=True) + + st.session_state.databases = { + db_type: Chroma( + collection_name=config.collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=config.persist_directory + ) for db_type, config in COLLECTIONS.items() + } + return True return False def process_document(file) -> List[Document]: - """Process uploaded PDF document""" - if not st.session_state.embeddings: - st.error("OpenAI API connection not initialized. Please check your API key.") - return [] - try: with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file: tmp_file.write(file.getvalue()) @@ -131,97 +88,99 @@ def process_document(file) -> List[Document]: loader = PyPDFLoader(tmp_path) documents = loader.load() - - # Clean up temporary file os.unlink(tmp_path) - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=1000, - chunk_overlap=200 - ) - texts = text_splitter.split_documents(documents) - - return texts + text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300) + return text_splitter.split_documents(documents) except Exception as e: st.error(f"Error processing document: {e}") return [] -def route_query(question: str) -> List[DatabaseType]: - """Route the question to appropriate databases""" - router_prompt = ChatPromptTemplate.from_template(ROUTER_TEMPLATE) - router_chain = router_prompt | st.session_state.llm | StrOutputParser() - response = router_chain.invoke({"question": question}) - return response.strip().lower().split(",") - -def query_multiple_databases(question: str) -> str: - """Query multiple relevant databases and combine results""" - database_types = route_query(question) - all_docs = [] - - # Collect relevant documents from each database - for db_type in database_types: - db = st.session_state.databases[db_type] - docs = db.similarity_search(question, k=2) # Reduced k since we're querying multiple DBs - all_docs.extend(docs) - - # Sort all documents by relevance score if available - # Note: You might need to modify this based on your similarity search implementation - context = "\n\n---\n\n".join([doc.page_content for doc in all_docs]) - - answer_prompt = ChatPromptTemplate.from_template( - """Answer the question based on the following context from multiple databases. - If you use information from multiple sources, please indicate which type of source it came from. - If you cannot answer the question based on the context, say "I don't have enough information to answer this question." - -Context: {context} - -Question: {question} - -Answer:""" +def create_routing_agent() -> Agent: + return Agent( + model=OpenAIChat(id="gpt-4o", api_key=st.session_state.openai_api_key), + tools=[], + description="You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.", + instructions=[ + "1. For questions about products, return 'products'", + "2. For questions about support, return 'support'", + "3. For questions about finance, return 'finance'", + "4. Return ONLY the database name" + ], + markdown=False, + show_tool_calls=False ) - - answer_chain = answer_prompt | st.session_state.llm | StrOutputParser() - return answer_chain.invoke({"context": context, "question": question}) -def clear_collection(collection_type: DatabaseType = None): - """Clear specified collection or all collections if none specified""" +def route_query(question: str) -> DatabaseType: try: - if collection_type: - if collection_type in st.session_state.databases: - collection_config = COLLECTIONS[collection_type] - # Delete collection - st.session_state.databases[collection_type]._collection.delete() - # Remove from session state - del st.session_state.databases[collection_type] - # Clean up persist directory - if os.path.exists(collection_config.persist_directory): - import shutil - shutil.rmtree(collection_config.persist_directory) - st.success(f"Cleared {collection_config.name} collection") - else: - # Clear all collections - for collection_type, collection_config in COLLECTIONS.items(): - if collection_type in st.session_state.databases: - st.session_state.databases[collection_type]._collection.delete() - if os.path.exists(collection_config.persist_directory): - import shutil - shutil.rmtree(collection_config.persist_directory) - st.session_state.databases = {} - st.success("Cleared all collections") + routing_agent = create_routing_agent() + response = routing_agent.run(question) + db_type = response.content.strip().lower().translate(str.maketrans('', '', '`\'"')) + + if db_type not in COLLECTIONS: + st.warning(f"Invalid database type: {db_type}, defaulting to products") + return "products" + + st.info(f"Routing question to {db_type} database") + return db_type except Exception as e: - st.error(f"Error clearing collection(s): {str(e)}") + st.error(f"Routing error: {str(e)}") + return "products" + +def create_fallback_agent(chat_model: BaseLanguageModel): + def web_research(query: str) -> str: + try: + search = DuckDuckGoSearchRun(num_results=5) + return search.run(query) + except Exception as e: + return f"Search failed: {str(e)}. Providing answer based on general knowledge." + + tools = [web_research] + return create_react_agent(model=chat_model, tools=tools, debug=False) + +def query_database(db: Chroma, question: str) -> tuple[str, list]: + try: + retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"k": 4, "score_threshold": 0.4}) + relevant_docs = retriever.get_relevant_documents(question) + + if relevant_docs: + retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat") + combine_docs_chain = create_stuff_documents_chain(st.session_state.llm, retrieval_qa_prompt) + retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain) + response = retrieval_chain.invoke({"input": question}) + return response['answer'], relevant_docs + + return _handle_web_fallback(question) + except Exception as e: + st.error(f"Error: {str(e)}") + return "I encountered an error. Please try rephrasing your question.", [] + +def _handle_web_fallback(question: str) -> tuple[str, list]: + st.info("No relevant documents found. Searching web...") + fallback_agent = create_fallback_agent(st.session_state.llm) + + with st.spinner('Researching...'): + agent_input = { + "messages": [HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")], + "is_last_step": False + } + + try: + response = fallback_agent.invoke(agent_input, config={"recursion_limit": 100}) + if isinstance(response, dict) and "messages" in response: + answer = response["messages"][-1].content + return f"Web Search Result:\n{answer}", [] + except Exception: + fallback_response = st.session_state.llm.invoke(question).content + return f"Web search unavailable. General response: {fallback_response}", [] def main(): - st.title("📚 RAG with Database Routing") - + st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚") + st.title("📚 RAG Agent with Database Routing") + with st.sidebar: st.header("Configuration") - api_key = st.text_input( - "Enter OpenAI API Key:", - type="password", - value=st.session_state.openai_api_key, - key="api_key_input" - ) + api_key = st.text_input("Enter OpenAI API Key:", type="password", value=st.session_state.openai_api_key, key="api_key_input") if api_key: st.session_state.openai_api_key = api_key @@ -233,53 +192,40 @@ def main(): if not st.session_state.openai_api_key: st.warning("Please enter your OpenAI API key to continue") st.stop() - - st.divider() - st.header("Database Management") - if st.button("Clear All Databases"): - clear_collection() - - st.divider() - st.subheader("Clear Individual Databases") - for collection_type, collection_config in COLLECTIONS.items(): - if st.button(f"Clear {collection_config.name}"): - clear_collection(collection_type) + + st.markdown("---") - # Document upload section st.header("Document Upload") - tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()]) + st.info("Upload documents to populate the databases. Each tab corresponds to a different database.") + tabs = st.tabs([config.name for config in COLLECTIONS.values()]) - for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs): + for (collection_type, config), tab in zip(COLLECTIONS.items(), tabs): with tab: - st.write(collection_config.description) - uploaded_file = st.file_uploader( - "Upload PDF document", - type="pdf", - key=f"upload_{collection_type}" - ) + st.write(config.description) + uploaded_files = st.file_uploader(f"Upload PDF documents to {config.name}", type="pdf", key=f"upload_{collection_type}", accept_multiple_files=True) - if uploaded_file: - with st.spinner('Processing document...'): - texts = process_document(uploaded_file) - if texts: + if uploaded_files: + with st.spinner('Processing documents...'): + all_texts = [] + for uploaded_file in uploaded_files: + texts = process_document(uploaded_file) + all_texts.extend(texts) + + if all_texts: db = st.session_state.databases[collection_type] - db.add_documents(texts) - st.success("Document processed and added to the database!") + db.add_documents(all_texts) + st.success("Documents processed and added to the database!") - # Query section st.header("Ask Questions") + st.info("Enter your question below to find answers from the relevant database.") question = st.text_input("Enter your question:") if question: with st.spinner('Finding answer...'): - # Get relevant databases - database_types = route_query(question) - - # Display routing information - st.info(f"Searching in: {', '.join([COLLECTIONS[db_type].name for db_type in database_types])}") - - # Get and display answer - answer = query_multiple_databases(question) + collection_type = route_query(question) + db = st.session_state.databases[collection_type] + st.info(f"Routing question to: {COLLECTIONS[collection_type].name}") + answer, relevant_docs = query_database(db, question) st.write("### Answer") st.write(answer) diff --git a/rag_tutorials/rag_database_routing/requirements.txt b/rag_tutorials/rag_database_routing/requirements.txt index 0ce6d76..c0259ec 100644 --- a/rag_tutorials/rag_database_routing/requirements.txt +++ b/rag_tutorials/rag_database_routing/requirements.txt @@ -1,9 +1,11 @@ -langchain>=0.1.0 -langchain-community>=0.0.10 -langchain-core>=0.1.10 -chromadb>=0.4.22 +langchain==0.3.12 +langchain-community==0.3.12 +langchain-core==0.3.28 +chromadb==0.5.20 streamlit>=1.29.0 -python-dotenv>=1.0.0 pypdf>=4.0.0 sentence-transformers>=2.2.2 -openai>=1.6.1 +phidata==2.7.3 +langchain-openai==0.2.14 +langgraph==0.2.53 +duckduckgo-search==6.4.1 \ No newline at end of file From d0c0798711b38e942545580bdad8b2fcdd4ca2b7 Mon Sep 17 00:00:00 2001 From: Madhu Date: Wed, 25 Dec 2024 11:30:40 +0530 Subject: [PATCH 4/6] final changes --- .../rag_database_routing.py | 141 ++++++++++++++---- 1 file changed, 112 insertions(+), 29 deletions(-) diff --git a/rag_tutorials/rag_database_routing/rag_database_routing.py b/rag_tutorials/rag_database_routing/rag_database_routing.py index 6582225..05e2f5b 100644 --- a/rag_tutorials/rag_database_routing/rag_database_routing.py +++ b/rag_tutorials/rag_database_routing/rag_database_routing.py @@ -1,12 +1,13 @@ import os -from typing import List, Dict, Literal +from typing import List, Dict, Any, Literal from dataclasses import dataclass import streamlit as st from langchain_core.documents import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores import Chroma -from langchain_openai import OpenAIEmbeddings, ChatOpenAI +from langchain_openai import OpenAIEmbeddings +from langchain_openai import ChatOpenAI import tempfile from phi.agent import Agent from phi.model.openai import OpenAIChat @@ -17,8 +18,10 @@ from langchain import hub from langgraph.prebuilt import create_react_agent from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.language_models import BaseLanguageModel +from langchain.prompts import ChatPromptTemplate def init_session_state(): + """Initialize session state variables""" if 'openai_api_key' not in st.session_state: st.session_state.openai_api_key = "" if 'embeddings' not in st.session_state: @@ -40,6 +43,7 @@ class CollectionConfig: collection_name: str persist_directory: str +# Collection configurations COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { "products": CollectionConfig( name="Product Information", @@ -62,25 +66,39 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { } def initialize_models(): + """Initialize OpenAI models with API key""" if st.session_state.openai_api_key: os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large") st.session_state.llm = ChatOpenAI(temperature=0) - for config in COLLECTIONS.values(): - os.makedirs(config.persist_directory, exist_ok=True) + # Ensure directories exist + for collection_config in COLLECTIONS.values(): + os.makedirs(collection_config.persist_directory, exist_ok=True) + # Initialize Chroma collections st.session_state.databases = { - db_type: Chroma( - collection_name=config.collection_name, + "products": Chroma( + collection_name=COLLECTIONS["products"].collection_name, embedding_function=st.session_state.embeddings, - persist_directory=config.persist_directory - ) for db_type, config in COLLECTIONS.items() + persist_directory=COLLECTIONS["products"].persist_directory + ), + "support": Chroma( + collection_name=COLLECTIONS["support"].collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=COLLECTIONS["support"].persist_directory + ), + "finance": Chroma( + collection_name=COLLECTIONS["finance"].collection_name, + embedding_function=st.session_state.embeddings, + persist_directory=COLLECTIONS["finance"].persist_directory + ) } return True return False def process_document(file) -> List[Document]: + """Process uploaded PDF document""" try: with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file: tmp_file.write(file.getvalue()) @@ -88,24 +106,37 @@ def process_document(file) -> List[Document]: loader = PyPDFLoader(tmp_path) documents = loader.load() + + # Clean up temporary file os.unlink(tmp_path) - text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300) - return text_splitter.split_documents(documents) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200 + ) + texts = text_splitter.split_documents(documents) + + return texts except Exception as e: st.error(f"Error processing document: {e}") return [] def create_routing_agent() -> Agent: + """Creates a routing agent using phidata framework""" return Agent( - model=OpenAIChat(id="gpt-4o", api_key=st.session_state.openai_api_key), + model=OpenAIChat( + id="gpt-4o", + api_key=st.session_state.openai_api_key + ), tools=[], - description="You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.", + description="""You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to. + You must respond with exactly one of these three options: 'products', 'support', or 'finance'. The user's question is: {question}""", instructions=[ - "1. For questions about products, return 'products'", - "2. For questions about support, return 'support'", - "3. For questions about finance, return 'finance'", - "4. Return ONLY the database name" + "Follow these rules strictly:", + "1. For questions about products, features, specifications, or item details, or product manuals → return 'products'", + "2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'", + "3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'", + "4. Return ONLY the database name, no other text or explanation" ], markdown=False, show_tool_calls=False @@ -115,42 +146,72 @@ def route_query(question: str) -> DatabaseType: try: routing_agent = create_routing_agent() response = routing_agent.run(question) - db_type = response.content.strip().lower().translate(str.maketrans('', '', '`\'"')) + db_type = (response.content + .strip() + .lower() + .translate(str.maketrans('', '', '`\'"'))) # More elegant string cleaning + + # Validate database type if db_type not in COLLECTIONS: st.warning(f"Invalid database type: {db_type}, defaulting to products") return "products" st.info(f"Routing question to {db_type} database") return db_type + except Exception as e: st.error(f"Routing error: {str(e)}") return "products" def create_fallback_agent(chat_model: BaseLanguageModel): + """Create a LangGraph agent for web research.""" + def web_research(query: str) -> str: + """Web search with result formatting.""" try: search = DuckDuckGoSearchRun(num_results=5) - return search.run(query) + results = search.run(query) + return results except Exception as e: return f"Search failed: {str(e)}. Providing answer based on general knowledge." tools = [web_research] - return create_react_agent(model=chat_model, tools=tools, debug=False) + + agent = create_react_agent(model=chat_model, + tools=tools, + debug=False) + + return agent def query_database(db: Chroma, question: str) -> tuple[str, list]: try: - retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"k": 4, "score_threshold": 0.4}) + retriever = db.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 4, "score_threshold": 0.3} + ) + relevant_docs = retriever.get_relevant_documents(question) if relevant_docs: - retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat") + # Use simpler chain creation with hub prompt + retrieval_qa_prompt = ChatPromptTemplate.from_messages([ + ("system", """You are a helpful AI assistant that answers questions based on provided context. + Always be direct and concise in your responses. + If the context doesn't contain enough information to fully answer the question, acknowledge this limitation. + Base your answers strictly on the provided context and avoid making assumptions."""), + ("human", "Here is the context:\n{context}"), + ("human", "Question: {input}"), + ("assistant", "I'll help answer your question based on the context provided."), + ("human", "Please provide your answer:"), + ]) combine_docs_chain = create_stuff_documents_chain(st.session_state.llm, retrieval_qa_prompt) retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain) + response = retrieval_chain.invoke({"input": question}) return response['answer'], relevant_docs - return _handle_web_fallback(question) + except Exception as e: st.error(f"Error: {str(e)}") return "I encountered an error. Please try rephrasing your question.", [] @@ -161,7 +222,9 @@ def _handle_web_fallback(question: str) -> tuple[str, list]: with st.spinner('Researching...'): agent_input = { - "messages": [HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")], + "messages": [ + HumanMessage(content=f"Research and provide a detailed answer for: '{question}'") + ], "is_last_step": False } @@ -170,17 +233,26 @@ def _handle_web_fallback(question: str) -> tuple[str, list]: if isinstance(response, dict) and "messages" in response: answer = response["messages"][-1].content return f"Web Search Result:\n{answer}", [] + except Exception: + # Fallback to general LLM response fallback_response = st.session_state.llm.invoke(question).content return f"Web search unavailable. General response: {fallback_response}", [] def main(): + """Main application function.""" st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚") st.title("📚 RAG Agent with Database Routing") + # Sidebar for API key and database management with st.sidebar: st.header("Configuration") - api_key = st.text_input("Enter OpenAI API Key:", type="password", value=st.session_state.openai_api_key, key="api_key_input") + api_key = st.text_input( + "Enter OpenAI API Key:", + type="password", + value=st.session_state.openai_api_key, + key="api_key_input" + ) if api_key: st.session_state.openai_api_key = api_key @@ -194,15 +266,20 @@ def main(): st.stop() st.markdown("---") - + st.header("Document Upload") st.info("Upload documents to populate the databases. Each tab corresponds to a different database.") - tabs = st.tabs([config.name for config in COLLECTIONS.values()]) + tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()]) - for (collection_type, config), tab in zip(COLLECTIONS.items(), tabs): + for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs): with tab: - st.write(config.description) - uploaded_files = st.file_uploader(f"Upload PDF documents to {config.name}", type="pdf", key=f"upload_{collection_type}", accept_multiple_files=True) + st.write(collection_config.description) + uploaded_files = st.file_uploader( + f"Upload PDF documents to {collection_config.name}", + type="pdf", + key=f"upload_{collection_type}", + accept_multiple_files=True + ) if uploaded_files: with st.spinner('Processing documents...'): @@ -216,15 +293,21 @@ def main(): db.add_documents(all_texts) st.success("Documents processed and added to the database!") + # Query section st.header("Ask Questions") st.info("Enter your question below to find answers from the relevant database.") question = st.text_input("Enter your question:") if question: with st.spinner('Finding answer...'): + # Route the question collection_type = route_query(question) db = st.session_state.databases[collection_type] + + # Display routing information st.info(f"Routing question to: {COLLECTIONS[collection_type].name}") + + # Get and display answer answer, relevant_docs = query_database(db, question) st.write("### Answer") st.write(answer) From f9c755dd32a40b8c95c4aab11d41b3d28705eb73 Mon Sep 17 00:00:00 2001 From: Madhu Date: Wed, 25 Dec 2024 13:13:12 +0530 Subject: [PATCH 5/6] added qdrant as db --- .../rag_database_routing.py | 200 ++++++++++++------ .../rag_database_routing/requirements.txt | 2 +- 2 files changed, 137 insertions(+), 65 deletions(-) diff --git a/rag_tutorials/rag_database_routing/rag_database_routing.py b/rag_tutorials/rag_database_routing/rag_database_routing.py index 05e2f5b..4fb23d2 100644 --- a/rag_tutorials/rag_database_routing/rag_database_routing.py +++ b/rag_tutorials/rag_database_routing/rag_database_routing.py @@ -1,11 +1,11 @@ import os -from typing import List, Dict, Any, Literal +from typing import List, Dict, Any, Literal, Optional from dataclasses import dataclass import streamlit as st from langchain_core.documents import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import PyPDFLoader -from langchain_community.vectorstores import Chroma +from langchain_community.vectorstores import Qdrant from langchain_openai import OpenAIEmbeddings from langchain_openai import ChatOpenAI import tempfile @@ -19,11 +19,17 @@ from langgraph.prebuilt import create_react_agent from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.language_models import BaseLanguageModel from langchain.prompts import ChatPromptTemplate +from qdrant_client import QdrantClient +from qdrant_client.models import Distance, VectorParams def init_session_state(): """Initialize session state variables""" if 'openai_api_key' not in st.session_state: st.session_state.openai_api_key = "" + if 'qdrant_url' not in st.session_state: + st.session_state.qdrant_url = "" + if 'qdrant_api_key' not in st.session_state: + st.session_state.qdrant_api_key = "" if 'embeddings' not in st.session_state: st.session_state.embeddings = None if 'llm' not in st.session_state: @@ -40,61 +46,68 @@ PERSIST_DIRECTORY = "db_storage" class CollectionConfig: name: str description: str - collection_name: str - persist_directory: str + collection_name: str # This will be used as Qdrant collection name # Collection configurations COLLECTIONS: Dict[DatabaseType, CollectionConfig] = { "products": CollectionConfig( name="Product Information", description="Product details, specifications, and features", - collection_name="products_collection", - persist_directory=f"{PERSIST_DIRECTORY}/products" + collection_name="products_collection" ), "support": CollectionConfig( name="Customer Support & FAQ", description="Customer support information, frequently asked questions, and guides", - collection_name="support_collection", - persist_directory=f"{PERSIST_DIRECTORY}/support" + collection_name="support_collection" ), "finance": CollectionConfig( name="Financial Information", description="Financial data, revenue, costs, and liabilities", - collection_name="finance_collection", - persist_directory=f"{PERSIST_DIRECTORY}/finance" + collection_name="finance_collection" ) } def initialize_models(): - """Initialize OpenAI models with API key""" - if st.session_state.openai_api_key: + """Initialize OpenAI models and Qdrant client""" + if (st.session_state.openai_api_key and + st.session_state.qdrant_url and + st.session_state.qdrant_api_key): + os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key - st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-small") st.session_state.llm = ChatOpenAI(temperature=0) - # Ensure directories exist - for collection_config in COLLECTIONS.values(): - os.makedirs(collection_config.persist_directory, exist_ok=True) - - # Initialize Chroma collections - st.session_state.databases = { - "products": Chroma( - collection_name=COLLECTIONS["products"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["products"].persist_directory - ), - "support": Chroma( - collection_name=COLLECTIONS["support"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["support"].persist_directory - ), - "finance": Chroma( - collection_name=COLLECTIONS["finance"].collection_name, - embedding_function=st.session_state.embeddings, - persist_directory=COLLECTIONS["finance"].persist_directory + try: + # Initialize Qdrant client with session state credentials + client = QdrantClient( + url=st.session_state.qdrant_url, + api_key=st.session_state.qdrant_api_key ) - } - return True + + # Test connection + client.get_collections() + vector_size = 1536 + st.session_state.databases = {} + for db_type, config in COLLECTIONS.items(): + try: + client.get_collection(config.collection_name) + except Exception: + # Create collection if it doesn't exist + client.create_collection( + collection_name=config.collection_name, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE) + ) + + st.session_state.databases[db_type] = Qdrant( + client=client, + collection_name=config.collection_name, + embeddings=st.session_state.embeddings + ) + + return True + except Exception as e: + st.error(f"Failed to connect to Qdrant: {str(e)}") + return False return False def process_document(file) -> List[Document]: @@ -136,33 +149,62 @@ def create_routing_agent() -> Agent: "1. For questions about products, features, specifications, or item details, or product manuals → return 'products'", "2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'", "3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'", - "4. Return ONLY the database name, no other text or explanation" + "4. Return ONLY the database name, no other text or explanation", + "5. If you're not confident about the routing, return an empty response" ], markdown=False, show_tool_calls=False ) -def route_query(question: str) -> DatabaseType: +def route_query(question: str) -> Optional[DatabaseType]: + """Route query by searching all databases and comparing relevance scores. + Returns None if no suitable database is found.""" try: + best_score = -1 + best_db_type = None + all_scores = {} # Store all scores for debugging + + # Search each database and compare relevance scores + for db_type, db in st.session_state.databases.items(): + results = db.similarity_search_with_score( + question, + k=3 + ) + + if results: + avg_score = sum(score for _, score in results) / len(results) + all_scores[db_type] = avg_score + + if avg_score > best_score: + best_score = avg_score + best_db_type = db_type + + confidence_threshold = 0.5 + if best_score >= confidence_threshold and best_db_type: + st.success(f"Using vector similarity routing: {best_db_type} (confidence: {best_score:.3f})") + return best_db_type + + st.warning(f"Low confidence scores (below {confidence_threshold}), falling back to LLM routing") + + # Fallback to LLM routing routing_agent = create_routing_agent() response = routing_agent.run(question) db_type = (response.content .strip() .lower() - .translate(str.maketrans('', '', '`\'"'))) # More elegant string cleaning + .translate(str.maketrans('', '', '`\'"'))) - # Validate database type - if db_type not in COLLECTIONS: - st.warning(f"Invalid database type: {db_type}, defaulting to products") - return "products" - - st.info(f"Routing question to {db_type} database") - return db_type + if db_type in COLLECTIONS: + st.success(f"Using LLM routing decision: {db_type}") + return db_type + + st.warning("No suitable database found, will use web search fallback") + return None except Exception as e: st.error(f"Routing error: {str(e)}") - return "products" + return None def create_fallback_agent(chat_model: BaseLanguageModel): """Create a LangGraph agent for web research.""" @@ -184,11 +226,12 @@ def create_fallback_agent(chat_model: BaseLanguageModel): return agent -def query_database(db: Chroma, question: str) -> tuple[str, list]: +def query_database(db: Qdrant, question: str) -> tuple[str, list]: + """Query the database and return answer and relevant documents""" try: retriever = db.as_retriever( - search_type="similarity_score_threshold", - search_kwargs={"k": 4, "score_threshold": 0.3} + search_type="similarity", + search_kwargs={"k": 4} ) relevant_docs = retriever.get_relevant_documents(question) @@ -210,7 +253,8 @@ def query_database(db: Chroma, question: str) -> tuple[str, list]: response = retrieval_chain.invoke({"input": question}) return response['answer'], relevant_docs - return _handle_web_fallback(question) + + raise ValueError("No relevant documents found in database") except Exception as e: st.error(f"Error: {str(e)}") @@ -244,9 +288,11 @@ def main(): st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚") st.title("📚 RAG Agent with Database Routing") - # Sidebar for API key and database management + # Sidebar for API keys and configuration with st.sidebar: st.header("Configuration") + + # OpenAI API Key api_key = st.text_input( "Enter OpenAI API Key:", type="password", @@ -254,15 +300,37 @@ def main(): key="api_key_input" ) + # Qdrant Configuration + qdrant_url = st.text_input( + "Enter Qdrant URL:", + value=st.session_state.qdrant_url, + help="Example: https://your-cluster.qdrant.tech" + ) + + qdrant_api_key = st.text_input( + "Enter Qdrant API Key:", + type="password", + value=st.session_state.qdrant_api_key + ) + + # Update session state if api_key: st.session_state.openai_api_key = api_key + if qdrant_url: + st.session_state.qdrant_url = qdrant_url + if qdrant_api_key: + st.session_state.qdrant_api_key = qdrant_api_key + + # Initialize models if all credentials are provided + if (st.session_state.openai_api_key and + st.session_state.qdrant_url and + st.session_state.qdrant_api_key): if initialize_models(): - st.success("API Key set successfully!") + st.success("Connected to OpenAI and Qdrant successfully!") else: - st.error("Invalid API Key") - - if not st.session_state.openai_api_key: - st.warning("Please enter your OpenAI API key to continue") + st.error("Failed to initialize. Please check your credentials.") + else: + st.warning("Please enter all required credentials to continue") st.stop() st.markdown("---") @@ -302,15 +370,19 @@ def main(): with st.spinner('Finding answer...'): # Route the question collection_type = route_query(question) - db = st.session_state.databases[collection_type] - # Display routing information - st.info(f"Routing question to: {COLLECTIONS[collection_type].name}") - - # Get and display answer - answer, relevant_docs = query_database(db, question) - st.write("### Answer") - st.write(answer) + if collection_type is None: + # Use web search fallback directly + answer, relevant_docs = _handle_web_fallback(question) + st.write("### Answer (from web search)") + st.write(answer) + else: + # Display routing information and query the database + st.info(f"Routing question to: {COLLECTIONS[collection_type].name}") + db = st.session_state.databases[collection_type] + answer, relevant_docs = query_database(db, question) + st.write("### Answer") + st.write(answer) if __name__ == "__main__": main() diff --git a/rag_tutorials/rag_database_routing/requirements.txt b/rag_tutorials/rag_database_routing/requirements.txt index c0259ec..0c69e77 100644 --- a/rag_tutorials/rag_database_routing/requirements.txt +++ b/rag_tutorials/rag_database_routing/requirements.txt @@ -1,7 +1,7 @@ langchain==0.3.12 langchain-community==0.3.12 langchain-core==0.3.28 -chromadb==0.5.20 +qdrant-client==1.12.1 streamlit>=1.29.0 pypdf>=4.0.0 sentence-transformers>=2.2.2 From ba7478407a98d28178a95ef1da22429d3bf020ca Mon Sep 17 00:00:00 2001 From: Madhu Date: Thu, 26 Dec 2024 09:25:24 +0530 Subject: [PATCH 6/6] resolves a few issues from other folders --- ai_agent_tutorials/ai_recruitment_agent_team/README.md | 4 ++-- .../multimodal_design_agent_team/requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ai_agent_tutorials/ai_recruitment_agent_team/README.md b/ai_agent_tutorials/ai_recruitment_agent_team/README.md index 0d99d09..748cf02 100644 --- a/ai_agent_tutorials/ai_recruitment_agent_team/README.md +++ b/ai_agent_tutorials/ai_recruitment_agent_team/README.md @@ -37,8 +37,8 @@ An Agentic recruitment system built on phidata and Streamlitthat automates the t 1. **Setup Environment** ```bash # Clone the repository - git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git - cd ai_agent_tutorials/ai_recruitment_agent_team + git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git + cd ai_agent_tutorials/ai_recruitment_agent_team # Install dependencies pip install -r requirements.txt diff --git a/ai_agent_tutorials/multimodal_design_agent_team/requirements.txt b/ai_agent_tutorials/multimodal_design_agent_team/requirements.txt index 13a4054..6ec9d03 100644 --- a/ai_agent_tutorials/multimodal_design_agent_team/requirements.txt +++ b/ai_agent_tutorials/multimodal_design_agent_team/requirements.txt @@ -1,5 +1,5 @@ google-generativeai==0.8.3 -streamlit==1.30.0 +streamlit==1.41.1 phidata==2.7.2 Pillow==11.0.0 duckduckgo-search==6.3.7