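Note: this article describes the nl2sql project our company recently put into production, written up here as a record.
A typical data analysis workflow looks roughly like this: the business side raises a data analysis requirement -> data source assessment -> data development -> data visualization application -> requirement delivery. As generative AI matures, its language understanding can be used to automate as much of this workflow as possible and so shorten delivery time. nl2sql is a typical application of large models in data analysis. The implementation is described below: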
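A pretrained model is powerful, but it usually knows nothing about a company's internal data, so it cannot perform as well as it should. The static model therefore has to be supplemented with knowledge of the company's internal data, and there are several common ways to do this.
In our project we use RAG: the company's internal data is organized into a knowledge base, which supplies the data the model needs while answering questions.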
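The retrieve-then-prompt idea behind RAG can be sketched as follows. This is a minimal illustration under stated assumptions, not the project's implementation: the table names, descriptions, the keyword-overlap retriever, and the helper names (`TableDoc`, `retrieve`, `build_prompt`) are all hypothetical, and a real system would replace the toy retriever with a proper index or knowledge-graph search.

```python
from dataclasses import dataclass


@dataclass
class TableDoc:
    """One knowledge-base entry describing a table (hypothetical schema)."""
    name: str
    description: str


# Toy knowledge base; table names and descriptions are made up for illustration.
KNOWLEDGE_BASE = [
    TableDoc("sales.orders", "order id, customer id, order amount, order date"),
    TableDoc("sales.customers", "customer id, customer name, region"),
    TableDoc("hr.employees", "employee id, employee name, department, hire date"),
]


def tokenize(text: str) -> set[str]:
    """Lowercase the text and split it into a set of alphanumeric words."""
    cleaned = "".join(ch.lower() if ch.isalnum() else " " for ch in text)
    return set(cleaned.split())


def retrieve(question: str, docs: list[TableDoc], top_k: int = 2) -> list[TableDoc]:
    """Rank docs by word overlap with the question; keep the best top_k with a nonzero score."""
    q_words = tokenize(question)
    scored = [(len(q_words & tokenize(d.name + " " + d.description)), d) for d in docs]
    scored = [(score, d) for score, d in scored if score > 0]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [d for _, d in scored[:top_k]]


def build_prompt(question: str, docs: list[TableDoc]) -> str:
    """Place the retrieved table descriptions in front of the user's question."""
    context = "\n".join(f"- {d.name}: {d.description}" for d in docs)
    return (
        "Use only the tables listed below to write a SQL query.\n"
        f"Tables:\n{context or '- (no relevant tables found)'}\n"
        f"Question: {question}"
    )


question = "What is the total order amount per customer region?"
hits = retrieve(question, KNOWLEDGE_BASE)
prompt = build_prompt(question, hits)
print(prompt)
```

The model itself stays unchanged; only the prompt grows by the retrieved context, which is why the quality of the knowledge base matters so much for the final answer.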
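For a RAG system, building the knowledge base efficiently is critical: a good knowledge base should be stable, reliable, and easy to extend. Our survey found that knowledge bases are usually built in one of a few ways.
Thanks to the language understanding of large models, knowledge graph construction can now be delegated to the model, which removes the time and labor cost of traditional manual construction, so more and more projects build their knowledge base as a knowledge graph. graphiti is an open-source tool for building knowledge graphs with large models. Our project combines two approaches, index-based retrieval and a knowledge graph: it first retrieves the center node through the index, then performs knowledge retrieval over the knowledge graph built with graphiti. An example of building a knowledge graph of database table structures with graphiti follows: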
import logging
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
# The helpers below are graphiti_core internals; their signatures may differ between versions.
from graphiti_core.utils.maintenance.edge_operations import (
    build_episodic_edges,
    resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.node_operations import (
    extract_attributes_from_nodes,
)

logger = logging.getLogger(__name__)


async def add_table_column_nodes(
    client: Graphiti,
    center_node_uuid: str,
    schema_name: str,
    table_name: str,
    table_comment: str,
    columns: list,
):
    start = datetime.now(timezone.utc)
    now = datetime.now(timezone.utc)
    previous_episodes = []
    # Load the episode that serves as the center node for this table
    episode: EpisodicNode = await EpisodicNode.get_by_uuid(client.driver, center_node_uuid)

    # Build one Column node per column, named schema.table.column
    extracted_nodes = []
    for column in columns:
        extracted_nodes.append(EntityNode(
            name=schema_name + "." + table_name + "." + column["column_name"],
            group_id=episode.group_id,
            labels=["Column"],
            created_at=now,
            attributes={
                "schema_name": schema_name,
                "table_name": table_name,
                "column_name": column["column_name"],
                "column_comment": column["column_comment"],
                "data_type": column["data_type"],
            },
        ))

    # Build the Table node, named schema.table
    table_node = EntityNode(
        name=schema_name + "." + table_name,
        group_id=episode.group_id,
        labels=["Table"],
        created_at=now,
        attributes={
            "schema_name": schema_name,
            "table_name": table_name,
            "table_comment": table_comment,
        },
    )
    nodes = [table_node] + extracted_nodes

    # Link the episode to the table node
    episodic_edges = build_episodic_edges([table_node], episode, now)

    # Link the table to each of its columns with a has_column edge
    edges = []
    for extracted_node in extracted_nodes:
        edges.append(EntityEdge(
            name="has_column",
            fact=table_node.name + " has column " + extracted_node.name,
            source_node_uuid=table_node.uuid,
            target_node_uuid=extracted_node.uuid,
            group_id=extracted_node.group_id,
            created_at=extracted_node.created_at,
        ))

    # Resolve the new edges against the graph; invalidated edges are not used here
    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        client.clients,
        edges,
        episode,
        nodes,
        {},
        {('Entity', 'Entity'): []},
    )
    entity_edges = resolved_edges

    hydrated_nodes = await extract_attributes_from_nodes(
        client.clients, nodes, episode, previous_episodes, None
    )

    # Write the episode, nodes, and edges to the graph in one bulk operation
    await add_nodes_and_edges_bulk(
        client.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, client.embedder
    )

    end = datetime.now(timezone.utc)
    # A timedelta has to be converted to seconds before scaling to milliseconds
    logger.info(f'Completed add_table_column_nodes in {(end - start).total_seconds() * 1000} ms')

With a knowledge graph of the database table structures in place, and the capabilities of a large model on top of it, we can build an nl2sql agent. A concrete implementation example follows:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode

# Project modules (not shown in this article)
from logger import logger
from rag_graphiti_agent import Agent, AgentRun
from rag_graphiti_chatbot import Chatbot, State
from rag_graphiti_client import client, llm
from rag_graphiti_mcp_tools import get_table_columns
from rag_graphiti_user import create_user, query_user

# Expose the table-column lookup as a tool the model can call
tools = [get_table_columns]
tool_node = ToolNode(tools)
llm_with_tools = llm.bind_tools(tools)

chatbot = Chatbot(
    llm=llm_with_tools,
    graphiti=client,
)

graph_builder = StateGraph(State)
memory = MemorySaver()


# Decide whether the agent should call a tool or finish
async def should_continue(state, config):
    # logger.info(f"should_continue state: {state}")
    messages = state['messages']
    last_message = messages[-1]
    # If the last message contains no tool call, we finish
    if not last_message.tool_calls:
        return 'end'
    # Otherwise we continue to the tool node
    return 'continue'


# agent -> (tools -> agent)* -> END
graph_builder.add_node('agent', chatbot.generate_chatbot_response)
graph_builder.add_node('tools', tool_node)
graph_builder.add_edge(START, 'agent')
graph_builder.add_conditional_edges('agent', should_continue, {'continue': 'tools', 'end': END})
graph_builder.add_edge('tools', 'agent')
graph = graph_builder.compile(checkpointer=memory)

# Top-level await: run this part in a notebook or inside an async function
user_name = "tcq"
user_node_uuid = await query_user(user_name)
if user_node_uuid is None:
    user_node_uuid = await create_user(user_name)

user_state = State(
    user_name=user_name,
    user_node_uuid=user_node_uuid,
)

agent = Agent(graph)
agent_run: AgentRun = agent.run(user_state=user_state)

To improve the accuracy of nl2sql, we use template-based dynamic prompts. The prompt template keeps the structure of the prompt from varying much between requests, which reduces misinterpretation by the model and makes its output noticeably more consistent; a stable template also lets the model's cache be reused, which improves performance. The dynamic part rewrites the user's input by adding conversational context, so the model understands the prompt better and answers more accurately. An example follows:
# Assumes the langchain_core message classes; edges_to_facts_string is a project helper not shown here.
from langchain_core.messages import AIMessage, SystemMessage


# Method of the Chatbot class
async def generate_chatbot_response(self, state: State):
    facts_string = None
    if len(state['messages']) > 0:
        last_message = state['messages'][-1]
        # Prefix the query with the speaker: the bot for AI messages, otherwise the user
        speaker = "TableColumnQueryBot" if isinstance(last_message, AIMessage) else state["user_name"]
        graphiti_query = f'{speaker}: {last_message.content}'
        # Search graphiti using the user's node uuid as the center node;
        # graph edges (facts) further from the user node are ranked lower
        edge_results = await self.graphiti.search(
            graphiti_query, center_node_uuid=state['user_node_uuid'], num_results=5
        )
        facts_string = edges_to_facts_string(edge_results)

    # Fixed template plus the dynamically retrieved facts
    system_message = SystemMessage(
        content=f"""You are a skillful table column relationship manager. Review information about the user and their prior conversation below and respond accordingly.
Keep responses short and concise. And remember, always be helpful!
Things you'll need to know about the user in order to generate a helpful response:
- the table they need to query
Ensure that you ask the user for the above if you don't already know.
Facts about the user and their conversation:
{facts_string or 'No facts about the user and their conversation'}"""
    )
    messages = [system_message] + state['messages']
    # ...

After going live, nl2sql achieved the expected results. It is now used widely within the team, and the demo is complete. The next step is to wrap nl2sql in an API and combine it with a SQL execution engine to build a conversation-based data analysis agent.