# Streamlit chat page for the Paperless-NGX RAG assistant.
import streamlit as st
|
||
|
|
from langchain.chains import create_retrieval_chain
|
||
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||
|
|
from langchain_core.prompts import ChatPromptTemplate
|
||
|
|
from src.rag_core import get_retriever, get_llm
|
||
|
|
|
||
|
|
# Configure the browser tab title and use the full page width.
# set_page_config must be the first Streamlit call executed on the page.
st.set_page_config(page_title="RAG Chat", layout="wide")

st.title("Paperless-NGX RAG Assistant")

# Build the retrieval stack once per script rerun.
# NOTE(review): `vectorstore` is bound but never used on this page —
# presumably kept for other pages or debugging; confirm before removing.
retriever, vectorstore, _ = get_retriever()
llm = get_llm()
|
# Prompt template: German system instructions plus the retrieved context,
# followed by the user's question in the "human" slot.
system_prompt = (
    "Du bist ein hilfreicher Assistent. Nutze den folgenden Kontext, um die Frage zu beantworten. "
    "Wenn du die Antwort nicht weißt, sage einfach, dass du sie nicht weißt.\n\n"
    "{context}"
)

messages = [
    ("system", system_prompt),
    ("human", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(messages)

# Chain that stuffs the retrieved documents into {context} and asks the LLM.
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
# Optional filter UI: lets the user restrict retrieval to one document.
with st.sidebar:
    st.header("Filter")
    filter_id = st.text_input("Nur in Document ID suchen (optional):")
|
# Persist the conversation across Streamlit reruns, then replay it.
st.session_state.setdefault("messages", [])

for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
|
if prompt_input := st.chat_input("Stelle eine Frage zu deinen Dokumenten..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt_input})
    with st.chat_message("user"):
        st.markdown(prompt_input)

    with st.chat_message("assistant"):
        # Build the retriever search kwargs, optionally restricted to a
        # single document via its Paperless ID.
        search_kwargs = {"k": 3}
        if filter_id:
            # FIX: int(filter_id) previously raised an uncaught ValueError on
            # non-numeric input and crashed the page. Validate via EAFP and
            # fall back to an unfiltered search with a visible warning.
            try:
                search_kwargs["filter"] = {"paperless_id": int(filter_id)}
            except ValueError:
                st.warning("Ungültige Document ID – Filter wird ignoriert.")

        # Temporarily override the retriever's search kwargs for this query.
        # NOTE(review): this mutates shared retriever state; it is reset on
        # every rerun unless get_retriever() caches — confirm in rag_core.
        retriever.search_kwargs = search_kwargs

        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        with st.spinner("Denke nach..."):
            response = rag_chain.invoke({"input": prompt_input})
            answer = response["answer"]
            sources = response.get("context", [])

        st.markdown(answer)

        # List the distinct source documents that served as context.
        if sources:
            st.write("---")
            st.write("**Quellen:**")
            unique_sources = {
                doc.metadata.get("source")
                for doc in sources
                if doc.metadata.get("source")
            }
            # FIX: sort for deterministic display order (set iteration
            # order is arbitrary and varied between reruns).
            for source in sorted(unique_sources):
                st.write(f"- {source}")

    st.session_state.messages.append({"role": "assistant", "content": answer})