9. Developing with OpenAI Large Models: Function Calling Applied to text2SQL



Preface

The example in this article is only meant to give you a deeper understanding of Function calling. Never try it out against a production environment!!
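The model is not fully reliable at generating correct SQL. Beyond the capabilities of the model itself, putting this into production would require many further improvements, such as providing a data dictionary, validating the SQL, and supporting connections to multiple databases.
That said, text2SQL is a good scenario for putting large models into practice, and it is relatively simple to implement. If you need it, feel free to explore it on your own; I will also follow up with a more complete text2SQL example.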

Example: implementing text2SQL with Function calling

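This example shows how to execute functions whose inputs are generated by the model, and uses this to build an agent that can answer our questions about a database. For simplicity, we use the Chinook sample database.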

import sqlite3
conn = sqlite3.connect("data/Chinook.db")
print("Opened database successfully")
Opened database successfully
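The SQL executed later comes straight from the model. One cheap safeguard, in the spirit of the "SQL validation" mentioned in the preface, is to open the database read-only. The sketch below assumes a local SQLite file. `open_readonly` and `looks_like_select` are hypothetical helper names and are not part of the original code. Also note that a plain `sqlite3.connect("data/Chinook.db")` silently creates a new, empty database if the file does not exist (as long as the `data` directory does), so "Opened database successfully" is printed even in that case. The `mode=ro` URI raises an error instead.

```python
import sqlite3
from pathlib import Path


def open_readonly(db_path):
    """Open an existing SQLite file read-only (hypothetical helper).

    Unlike sqlite3.connect(path), this raises sqlite3.OperationalError when
    the file is missing instead of creating an empty database, and any write
    statement fails with "attempt to write a readonly database".
    """
    # as_uri() needs an absolute path; resolve() provides one, and as_uri()
    # percent-encodes characters such as spaces, '?' and '#'.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def looks_like_select(query):
    """Rough allow-list check: only statements starting with SELECT or WITH pass.

    This is not a SQL parser. WITH can also prefix DELETE/UPDATE, so the
    read-only connection is what actually prevents writes.
    """
    words = query.strip().split(None, 1)
    return bool(words) and words[0].upper() in ("SELECT", "WITH")
```

Under this assumption, `conn = open_readonly("data/Chinook.db")` can replace the original connection. The schema helpers and `ask_database` only read from the database.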
def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)
print(database_schema_string)
Table: Album
Columns: AlbumId, Title, ArtistId
Table: Artist
Columns: ArtistId, Name
Table: Customer
Columns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId
Table: Employee
Columns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email
Table: Genre
Columns: GenreId, Name
Table: Invoice
Columns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total
Table: InvoiceLine
Columns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity
Table: MediaType
Columns: MediaTypeId, Name
Table: Playlist
Columns: PlaylistId, Name
Table: PlaylistTrack
Columns: PlaylistId, TrackId
Table: Track
Columns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice
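The schema string above lists only column names. Giving the model the column types and the foreign-key relationships, a lightweight form of the data dictionary mentioned in the preface, may help it pick correct joins. The sketch below uses SQLite's `PRAGMA table_info` and `PRAGMA foreign_key_list`. `describe_schema` is a hypothetical name, and unlike `get_table_names` it skips SQLite's internal `sqlite_%` tables.

```python
import sqlite3


def describe_schema(conn):
    """Build a schema string with column types and foreign keys (hypothetical helper)."""
    lines = []
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
    ).fetchall()
    for (table_name,) in tables:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        cols = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
        col_desc = ", ".join(f"{c[1]} {c[2]}".strip() for c in cols)
        lines.append(f"Table: {table_name}\nColumns: {col_desc}")
        # PRAGMA foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
        for fk in conn.execute(f"PRAGMA foreign_key_list('{table_name}');").fetchall():
            # "to" is NULL when the reference implicitly targets the parent's primary key
            target = fk[4] if fk[4] is not None else "(primary key)"
            lines.append(f"Foreign key: {table_name}.{fk[3]} -> {fk[2]}.{target}")
    return "\n".join(lines)
```

The result can be passed wherever `database_schema_string` is used. The exact types printed for Chinook depend on how that file's tables were declared.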

Add the database's table and column information to the tool description

tools = [
    {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required": ["query"],
            },
        }
    }
]
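The triple-quoted f-string above sends its source-code indentation to the model as part of the parameter description. Calling `textwrap.dedent` on the finished f-string would remove nothing, because the interpolated schema lines start at column 0. The sketch below therefore dedents the template first and inserts the schema afterwards. `make_ask_database_tool` is a hypothetical helper, and the wording is copied from the original definition.

```python
import textwrap

# Dedent the template *before* inserting the schema; otherwise the unindented
# schema lines make the common indentation zero and dedent strips nothing.
_QUERY_DESCRIPTION = textwrap.dedent("""\
    SQL query extracting info to answer the user's question.
    SQL should be written using this database schema:
    {schema}
    The query should be returned in plain text, not in JSON.""")


def make_ask_database_tool(schema_string):
    """Return a tools list equivalent to the one above, without the indentation."""
    return [
        {
            "type": "function",
            "function": {
                "name": "ask_database",
                "description": "Use this function to answer user questions about music. "
                               "Input should be a fully formed SQL query.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": _QUERY_DESCRIPTION.format(schema=schema_string),
                        }
                    },
                    "required": ["query"],
                },
            },
        }
    ]
```

Usage: `tools = make_ask_database_tool(database_schema_string)`.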

Define a function that queries the database

def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results
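`ask_database` returns `str(fetchall())`. A query without `LIMIT` can therefore return thousands of rows, all of which end up in the next prompt, and the result carries no column names. The sketch below is a variant that caps the rows with `cursor.fetchmany` and returns JSON that includes the column names. The name `ask_database_limited` and the `max_rows=50` default are assumptions for illustration.

```python
import json


def ask_database_limited(conn, query, max_rows=50):
    """Like ask_database, but return at most max_rows rows as JSON with column names."""
    try:
        cursor = conn.execute(query)
        # Fetch one extra row so we can tell whether the result was cut off
        rows = cursor.fetchmany(max_rows + 1)
    except Exception as e:
        return f"query failed with error: {e}"
    # cursor.description is None for statements that return no columns
    columns = [d[0] for d in cursor.description] if cursor.description else []
    return json.dumps(
        {
            "columns": columns,
            "rows": [list(r) for r in rows[:max_rows]],
            "truncated": len(rows) > max_rows,
        },
        ensure_ascii=False,  # keep non-ASCII text readable
        default=str,  # values JSON cannot encode (e.g. BLOBs) fall back to str()
    )
```

It takes the same first two arguments as `ask_database`, so it can be swapped in at the call site.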

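Steps for having the model generate SQL, executing it, and returning the result:

  1. Send the model a prompt that may lead it to choose a tool. If it does, the function name and arguments are included in the response.
  2. Check programmatically whether the model wants to call a function. If it does, continue to step 3.
  3. Extract the function name and arguments from the response, call the function with those arguments, and append the result to the messages.
  4. Call the model again with the message list to get the final response.
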
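Steps 2 and 3 can be factored into a small dispatcher that looks up each requested function by name. This is a sketch that assumes the tool-call objects expose `.id`, `.function.name` and `.function.arguments`, as the items of `response_message.tool_calls` do in the OpenAI Python SDK. `run_tool_calls` is a hypothetical name.

```python
import json


def run_tool_calls(tool_calls, available_functions):
    """Execute each requested tool call and build the matching "tool" messages.

    available_functions maps a function name to a callable that takes the
    parsed JSON arguments as keyword arguments.
    """
    tool_messages = []
    for call in tool_calls or []:  # tool_calls is None when the model answered directly
        func = available_functions.get(call.function.name)
        if func is None:
            result = f"unknown function: {call.function.name}"
        else:
            try:
                result = func(**json.loads(call.function.arguments))
            except (json.JSONDecodeError, TypeError) as e:
                # The model can emit malformed JSON or unexpected parameter names
                result = f"invalid arguments: {e}"
        # Every tool_call_id must receive exactly one "tool" message
        tool_messages.append({"role": "tool", "tool_call_id": call.id, "content": str(result)})
    return tool_messages
```

For this article's tool, the registry could be `{"ask_database": lambda query: ask_database(conn, query)}`. The returned messages are appended after `response_message` before step 4.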
import os
from dotenv import load_dotenv
from openai import OpenAI
# Load the API key and the base URL of the OpenAI-compatible endpoint from a .env file
load_dotenv()
WILDCARD_API_KEY = os.getenv('WILDCARD_API_KEY')
WILDCARD_API = os.getenv('WILDCARD_API')
client = OpenAI(api_key=WILDCARD_API_KEY, base_url=WILDCARD_API)
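messages = [
    {
        "role": "system",
        "content": "You are an all-purpose assistant that can do anything: you can execute all kinds of functions, such as addition, getting the weather, and generating SQL to query a database. Call the appropriate function when needed. Do not add any explanation to your answers."
    },
    {
        "role": "user",
        # "content": "What is the name of the album with the most tracks?"
        "content": "Which album has the most tracks?"
    }
]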

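response = client.chat.completions.create(
    model='gpt-3.5-turbo-0613',  # snapshot used for this article; OpenAI has since deprecated it, so substitute a current tool-calling model
    messages=messages,
    tools=tools,
    tool_choice="auto"  # let the model decide whether to call a tool
)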

# The assistant message; when the model decides to call a tool, tool_calls carries the generated SQL
response_message = response.choices[0].message
print(response_message)
ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')])
tool_calls = response_message.tool_calls
print(tool_calls)
[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')]
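import json
if tool_calls:
    # Keep only the system and user messages so that re-running this cell does not pile up old turns
    messages = messages[0:2]
    messages.append(response_message)  # the assistant message that carries the tool calls
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        if function_name == "ask_database":
            # arguments is a JSON string such as '{"query": "SELECT ..."}'
            query = json.loads(tool_call.function.arguments)["query"]
            results = ask_database(conn, query)
        else:
            results = f"unknown function: {function_name}"
        # Every tool_call_id in response_message needs a matching "tool" message
        messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": str(results),
            }
        )

    # Call the model once, after the results of *all* tool calls have been appended
    print(messages)
    fin_response = client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        messages=messages
    )
    print("----")
    print(fin_response.choices[0].message.content)
else:
    # The model answered directly without requesting a tool
    print(response_message.content)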
[{'role': 'system', 'content': 'You are an all-purpose assistant that can do anything: you can execute all kinds of functions, such as addition, getting the weather, and generating SQL to query a database. Call the appropriate function when needed. Do not add any explanation to your answers.'}, {'role': 'user', 'content': 'Which album has the most tracks?'}, ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_GHZ0dK7cOoLoe6uE9Irrnl00', function=Function(arguments='{"query":"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')]), {'role': 'tool', 'tool_call_id': 'call_GHZ0dK7cOoLoe6uE9Irrnl00', 'name': 'ask_database', 'content': "[('Greatest Hits', 57)]"}]
----
The album “Greatest Hits” has the most tracks, 57 in total.

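In the final printed output we can see the tool that was called, the generated SQL, and the result of executing that SQL. This is exactly the capability that Function calling provides.
Function calling was first introduced by OpenAI, and the capability comes from fine-tuning the model. Other vendors have since added the same capability to their own models; Alibaba's Qwen models, for example, were specifically fine-tuned and optimized for calling external functions.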
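The example above performs exactly one round: one batch of tool calls, then one final completion. When a question needs several queries, the model may request another tool call after seeing the first result. The sketch below loops until the model answers without tool calls. `run_conversation` and `max_rounds` are hypothetical, and `client` is assumed to be the `OpenAI` instance created earlier (anything exposing `client.chat.completions.create` works).

```python
import json


def run_conversation(client, model, messages, tools, available_functions, max_rounds=5):
    """Call the model repeatedly until it stops requesting tools or max_rounds is reached."""
    messages = list(messages)  # work on a copy so the caller's list is left untouched
    for _ in range(max_rounds):
        response = client.chat.completions.create(
            model=model, messages=messages, tools=tools, tool_choice="auto"
        )
        message = response.choices[0].message
        if not message.tool_calls:
            return message.content  # final answer, no more tools requested
        messages.append(message)
        for call in message.tool_calls:
            func = available_functions.get(call.function.name)
            if func is None:
                result = f"unknown function: {call.function.name}"
            else:
                result = func(**json.loads(call.function.arguments))
            messages.append({"role": "tool", "tool_call_id": call.id, "content": str(result)})
    raise RuntimeError(f"no final answer after {max_rounds} rounds")
```

A possible call is `run_conversation(client, "<your model>", messages[:2], tools, {"ask_database": lambda query: ask_database(conn, query)})`. The `max_rounds` cap keeps a model that keeps requesting tools from looping forever.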
