Building on OpenAI large models: applying Function calling to text2SQL
Preface
The example in this article is only meant to give you a deeper understanding of what Function calling can do. Never run or verify it against a production environment!
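The model is not fully reliable at generating correct SQL. Beyond the model's own capability, production use would need many more pieces, such as supplying a data dictionary, validating the generated SQL, and supporting multiple database connections.
That said, text2SQL is a good scenario for putting large models to practical use, and it is relatively simple to implement, so it is worth exploring if you have the need. I will also publish a more complete text2SQL example later.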
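Example: implementing text2SQL with Function calling
This example shows how to execute a function whose input is generated by the model, and uses it to build an agent that can answer our questions about a database. For simplicity, we will use the Chinook sample database.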
import sqlite3
conn = sqlite3.connect("data/Chinook.db")
print("Opened database successfully")
Opened database successfully
def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names

def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names

def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)
print(database_schema_string)
Table: Album
Columns: AlbumId, Title, ArtistId
Table: Artist
Columns: ArtistId, Name
Table: Customer
Columns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId
Table: Employee
Columns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email
Table: Genre
Columns: GenreId, Name
Table: Invoice
Columns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total
Table: InvoiceLine
Columns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity
Table: MediaType
Columns: MediaTypeId, Name
Table: Playlist
Columns: PlaylistId, Name
Table: PlaylistTrack
Columns: PlaylistId, TrackId
Table: Track
Columns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice
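The schema string above lists column names only. As a small step toward the data dictionary mentioned in the preface, the sketch below also includes each column's declared type, read from the same PRAGMA table_info call. The helper names get_column_details and describe_schema are hypothetical and not part of the original example; whether the extra detail actually improves the generated SQL is an assumption that would need testing.

```python
import sqlite3

def get_column_details(conn, table_name):
    """Return 'name TYPE' strings for each column (hypothetical helper)."""
    rows = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    # Each row is (cid, name, type, notnull, dflt_value, pk)
    return [f"{row[1]} {row[2]}".strip() for row in rows]

def describe_schema(conn):
    """Build a schema string in the same layout as above, with declared types added."""
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    return "\n".join(
        f"Table: {name}\nColumns: {', '.join(get_column_details(conn, name))}"
        for (name,) in tables
    )

# Demo on a throwaway in-memory database rather than Chinook
demo = sqlite3.connect(":memory:")
demo.execute("CREATE TABLE Artist (ArtistId INTEGER, Name TEXT)")
print(describe_schema(demo))
```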
Add the table and column information to the tool description
tools = [
    {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required": ["query"],
            },
        }
    }
]
Define a function that queries the database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results
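Because ask_database runs whatever SQL the model produces, one possible safeguard is to keep the connection from writing while that SQL executes. The sketch below uses SQLite's query_only pragma for this. The function name ask_database_read_only is hypothetical, and this is only one layer of protection under the assumption that the database is SQLite; it is not a complete SQL validation scheme.

```python
import sqlite3

def ask_database_read_only(conn, query):
    """Run a model-generated query while SQLite's query_only pragma is on."""
    # With query_only on, SQLite rejects statements that would change data
    conn.execute("PRAGMA query_only = ON;")
    try:
        results = str(conn.execute(query).fetchall())
    except sqlite3.Error as e:
        results = f"query failed with error: {e}"
    finally:
        # Restore normal behaviour for the rest of the program
        conn.execute("PRAGMA query_only = OFF;")
    return results

# Demo on a throwaway in-memory database rather than Chinook
demo = sqlite3.connect(":memory:")
demo.execute("CREATE TABLE Artist (ArtistId INTEGER, Name TEXT)")
demo.execute("INSERT INTO Artist VALUES (1, 'AC/DC')")
demo.commit()
print(ask_database_read_only(demo, "SELECT Name FROM Artist"))
print(ask_database_read_only(demo, "DELETE FROM Artist"))
```

The second call should come back as a "query failed" string instead of deleting rows. Opening the database file in read-only mode from the start is another option worth considering.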
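Steps for having the model generate SQL, executing it, and returning the result:
- Prompt the model with content that may lead it to select a tool. If it does, the function name and arguments are included in the response.
- Check programmatically whether the model wants to call a function. If so, continue with the next step.
- Extract the function name and arguments from the response, call the function with those arguments, and append the result to the messages.
- Call the model again with the message list to get the final response.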
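The steps above can be sketched as a single helper. This is a minimal illustration, not the code used in the rest of the article: answer_with_tools, create_completion, and handlers are hypothetical names, and create_completion is assumed to be a thin wrapper around client.chat.completions.create that returns the same response shape.

```python
import json

def answer_with_tools(create_completion, messages, tools, handlers):
    """One tool-calling round trip (hypothetical helper).

    create_completion: callable assumed to wrap client.chat.completions.create
    handlers: dict mapping a function name to a callable taking keyword arguments
    """
    # Step 1: send the conversation and the tool descriptions to the model
    response = create_completion(messages=messages, tools=tools)
    message = response.choices[0].message

    # Step 2: if no tool was requested, the content is already the answer
    if not message.tool_calls:
        return message.content

    # Step 3: run each requested function and append its result
    messages = messages + [message]
    for call in message.tool_calls:
        args = json.loads(call.function.arguments)
        handler = handlers.get(call.function.name)
        if handler is None:
            result = f"unknown function: {call.function.name}"
        else:
            result = handler(**args)
        messages.append(
            {"role": "tool", "tool_call_id": call.id, "content": str(result)}
        )

    # Step 4: call the model again so it can phrase the final answer
    final = create_completion(messages=messages)
    return final.choices[0].message.content
```

With the objects defined in this article, the wiring might look like create_completion = lambda **kw: client.chat.completions.create(model="gpt-3.5-turbo-0613", **kw) and handlers = {"ask_database": lambda query: ask_database(conn, query)}.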
import os
from dotenv import load_dotenv
from openai import OpenAI
# Load the API key and base URL from the .env file
load_dotenv()
WILDCARD_API_KEY = os.getenv('WILDCARD_API_KEY')
WILDCARD_API = os.getenv('WILDCARD_API')
client = OpenAI(api_key=WILDCARD_API_KEY, base_url=WILDCARD_API)
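messages = [
    {
        "role": "system",
        # English gloss of the prompt: "You are an all-purpose assistant that can
        # execute various functions, such as addition, getting the weather, and
        # generating SQL to query a database. Call the appropriate function when
        # needed. Do not add any explanation to your answers."
        "content": "你是一名全能小助手,无所不能,可以执行各种函数功能,如加法计算、获取天气、生成SQL查询数据库等。在需要时调用适当的函数来处理。对于回答不作任何解释"
    },
    {
        "role": "user",
        # English version of the same question:
        # "content": "What is the name of the album with the most tracks?"
        "content": "哪个专辑里面的曲目最多?"
    }
]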
response = client.chat.completions.create(
    model='gpt-3.5-turbo-0613',
    messages=messages,
    tools=tools,
    tool_choice="auto"
)
# Append the message to messages list
response_message = response.choices[0].message
# messages.append(response_message)
print(response_message)  # response_message contains the SQL generated by the model
ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')])
tool_calls = response_message.tool_calls
print(tool_calls)
[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')]
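As the printed tool call shows, the arguments field is a JSON string, not a Python dict, and nothing guarantees that the model always emits valid JSON or includes the query key. A defensive way to pull out the SQL might look like the sketch below; extract_query is a hypothetical helper and is not used elsewhere in this article.

```python
import json

def extract_query(arguments):
    """Return the 'query' value from a tool call's JSON arguments.

    Returns None when the JSON is malformed or the key is missing or empty.
    """
    try:
        parsed = json.loads(arguments)
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(parsed, dict):
        return None
    query = parsed.get("query")
    if isinstance(query, str) and query.strip():
        return query
    return None

print(extract_query('{"query": "SELECT COUNT(*) FROM Track;"}'))
print(extract_query("not valid json"))
```

Returning None lets the caller decide whether to retry the request or report the failure back to the model instead of raising inside the tool loop.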
import json

if tool_calls:
    # Keep only the system and user messages so this code can be re-run safely
    messages = messages[0:2]
    messages.append(response_message)
    for tool_call in tool_calls:
        tool_call_id = tool_call.id
        function_name = tool_call.function.name
        # The arguments arrive as a JSON string; pull out the generated SQL
        function_args = json.loads(tool_call.function.arguments)['query']
        if function_name == "ask_database":
            results = ask_database(conn, function_args)
        else:
            # Avoid an undefined `results` if the model names an unknown function
            results = f"Error: function {function_name} does not exist"
        # Return the tool output to the model, linked by tool_call_id
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call_id,
                "name": function_name,
                "content": str(results)
            },
        )
    print(messages)
    # Second call: the model turns the raw query result into an answer
    fin_response = client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        messages=messages
    )
    print("----")
    print(fin_response.choices[0].message.content)
[{'role': 'system', 'content': '你是一名全能小助手,无所不能,可以执行各种函数功能,如加法计算、获取天气、生成SQL查询数据库等。在需要时调用适当的函数来处理。对于回答不作任何解释'}, {'role': 'user', 'content': '哪个专辑里面的曲目最多?'}, ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')]), {'role': 'tool', 'tool_call_id': 'call_GHZ0dK7cOoLoe6uE9Irrnl00', 'name': 'ask_database', 'content': "[('Greatest Hits', 57)]"}]
----
“Greatest Hits”专辑拥有的曲目最多,共有57首。
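In the final printed output we can see the tool that was called, the generated SQL, and the result of executing that SQL. This is the capability that Function calling provides.
Function calling was first introduced by OpenAI, and the capability comes from fine-tuning the model. Other vendors have since followed with their own models; Alibaba's Qwen models, for example, have been specifically fine-tuned and optimized for calling external functions.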