This commit is contained in:
2023-07-16 14:37:12 +02:00
parent a4e25ba442
commit 16ef0d0ebd
3 changed files with 285 additions and 30 deletions

2
.gitignore vendored
View File

@ -3,4 +3,4 @@
!/LICENSE
!/build.sbt
!/README.md
openai-scala-client.conf
openai-scala-client.conf

236
README.md
View File

@ -1,2 +1,238 @@
<!-- LTeX: language=en-US -->
# chat-sql
chat-sql is a cli, which converts queries written in natural language to SQL
queries via ChatGPT. It connects a PostgreSQL database, to execute the
queries directly. The schema of the all tables of the database is sent to ChatGPT
in order to provide better results.
## Setup
Make sure you have a Scala 3 compiler and sbt installed. Then you can clone the
git repository.
To use the ChatGPT api you have to create an API key on the
[OpenAI Platform](https://platform.openai.com/account/api-keys). You also need
to buy tokens or use the tokens you get for free the first three months after
you created your OpenAI account.
To configure chat-sql, create a config file `chatSql.conf` in the git directory
with the following content:
```
database = "jdbc:postgresql://localhost:5432/database_name?user=username&password=password"
apiKey = "your api key"
```
Replace the values accordingly.
## Usage
Start the program with `sbt run` and enter your query into the prompt.
If you do not get the expected result you can try to adjust ChatGPT's
temperature. The temperature parameter in GPT-based models affects the
randomness of generated text, including SQL queries. A higher temperature leads
to more diverse and creative queries, but they may be less accurate or syntactically correct.
A lower temperature produces more focused and deterministic queries,
adhering closely to standard syntax. Choosing the temperature depends on the desired
trade-off between query accuracy and creativity. Lower values prioritize
correctness, while higher values introduce more randomness and exploration.
Experimentation with different temperature values helps find the right balance
for generating SQL queries from natural language.
## Example output
First I asked the program, to create the following tables and fill it with
example data.
Table: Customers
| Column Name | Data Type |
|--------------|----------|
| customer_id | integer |
| name | varchar |
| email | varchar |
| address | varchar |
| city | varchar |
| state | varchar |
| country | varchar |
Table: Orders
| Column Name | Data Type |
|-------------|----------|
| order_id | integer |
| customer_id | integer |
| order_date | date |
| total_amount| decimal |
| status | varchar |
```
$ create the follwing tables and fill it with 10 example customers and 20 example orders: Table: Customers Columns: customer_id (integer) primary key, name (varchar), email (varchar), address (varchar), city (varchar), state (varchar), country (varchar) Table: Orders Columns: order_id (integer) primary key, customer_id (integer), order_date (date), total_amount (decimal), status (varchar)
CREATE TABLE Customers CREATE TABLE Customers (
customer_id INTEGER PRIMARY KEY,
name VARCHAR,
email VARCHAR,
address VARCHAR,
city VARCHAR,
state VARCHAR,
country VARCHAR
);
CREATE TABLE Orders (
order_id INTEGER PRIMARY KEY,
customer_id INTEGER,
order_date DATE,
total_amount DECIMAL,
status VARCHAR
);
INSERT INTO Customers (customer_id, name, email, address, city, state, country)
VALUES (1, 'John Doe', 'johndoe@example.com', '123 Main St', 'New York', 'NY', 'USA'),
(2, 'Jane Smith', 'janesmith@example.com', '456 Elm St', 'Los Angeles', 'CA', 'USA'),
(3, 'Robert Johnson', 'robertjohnson@example.com', '789 Oak St', 'Chicago', 'IL', 'USA'),
(4, 'Emily Davis', 'emilydavis@example.com', '321 Pine St', 'Houston', 'TX', 'USA'),
(5, 'Michael Wilson', 'michaelwilson@example.com', '654 Cedar St', 'Miami', 'FL', 'USA'),
(6, 'Jessica Brown', 'jessicabrown@example.com', '987 Birch St', 'San Francisco', 'CA', 'USA'),
(7, 'David Anderson', 'davidanderson@example.com', '159 Maple St', 'Seattle', 'WA', 'USA'),
(8, 'Jennifer Martinez', 'jennifermartinez@example.com', '753 Spruce St', 'Boston', 'MA', 'USA'),
(9, 'Daniel Thompson', 'danielthompson@example.com', '852 Walnut St', 'Phoenix', 'AZ', 'USA'),
(10, 'Sarah Garcia', 'sarahgarcia@example.com', '369 Ash St', 'Denver', 'CO', 'USA');
INSERT INTO Orders (order_id, customer_id, order_date, total_amount, status)
VALUES (1, 1, '2021-01-01', 100.00, 'Completed'),
(2, 1, '2021-02-15', 50.00, 'Pending'),
(3, 2, '2021-03-10', 200.00, 'Completed'),
(4, 3, '2021-04-05', 75.00, 'Completed'),
(5, 4, '2021-05-20', 120.00, 'Pending'),
(6, 5, '2021-06-12', 80.00, 'Completed'),
(7, 6, '2021-07-18', 150.00, 'Pending'),
(8, 7, '2021-08-09', 90.00, 'Completed'),
(9, 8, '2021-09-25', 180.00, 'Pending'),
(10, 9, '2021-10-30', 60.00, 'Completed'),
(11, 10, '2021-11-22', 135.00, 'Pending'),
(12, 2, '2021-12-15', 95.00, 'Completed'),
(13, 3, '2022-01-02', 70.00, 'Completed'),
(14, 4, '2022-02-25', 160.00, 'Pending'),
(15, 5, '2022-03-18', 110.00, 'Completed'),
(16, 6, '2022-04-14', 45.00, 'Completed'),
(17, 7, '2022-05-06', 75.00, 'Pending'),
(18, 8, '2022-06-29', 140.00, 'Completed'),
(19, 9, '2022-07-23', 105.00, 'Pending'),
(20, 10, '2022-08-12', 65.00, 'Completed');
```
This took a second attempt, the first time, it did not generate the example data
correctly because it did not write all examples out and used "..." instead.
I also wanted to add some more orders, so each customer does not have the same
amount of orders.
```
$ insert 15 different example orders and write out for all 15 orders. only use customer_id 1 to 10 and order_id above 20
INSERT INTO orders (order_id, customer_id, order_date, total_amount, status) VALUES
(21, 1, '2021-01-01', 100.00, 'Delivered'),
(22, 2, '2021-01-02', 150.00, 'Delivered'),
(23, 3, '2021-01-03', 200.00, 'Delivered'),
(24, 4, '2021-01-04', 250.00, 'Delivered'),
(25, 5, '2021-01-05', 300.00, 'Delivered'),
(26, 6, '2021-01-06', 350.00, 'Delivered'),
(27, 7, '2021-01-07', 400.00, 'Delivered'),
(28, 8, '2021-01-08', 450.00, 'Delivered'),
(29, 9, '2021-01-09', 500.00, 'Delivered'),
(30, 10, '2021-01-10', 550.00, 'Delivered'),
(31, 1, '2021-01-11', 600.00, 'Delivered'),
(32, 2, '2021-01-12', 650.00, 'Delivered'),
(33, 3, '2021-01-13', 700.00, 'Delivered'),
(34, 4, '2021-01-14', 750.00, 'Delivered'),
(35, 5, '2021-01-15', 800.00, 'Delivered');
```
I had to be a bit more specific here with the IDs because only the
schema is provided to ChatGPT and therefore can not know which IDs it
can use. It would be interesting to input the complete table into GPT,
but for large tables it will be to much data.
It also used 'Delivered' as status instead of 'Completed'. However, it
does not matter much for the follow-up questions.
Then I let ChatGPT themself come up with the following tasks:
1. Write a query to retrieve the total number of customers in the database.
```
SELECT COUNT(*) AS total_customers FROM customers;
total_customers: 10
```
2. Write a query to calculate the total revenue generated from all orders.
```
SELECT SUM(total_amount) AS total_revenue FROM orders;
total_revenue: 8855.00
```
3. Write a query to find the top 5 customers who have placed the highest total amount of orders.
```
SELECT c.customer_id, c.name, SUM(o.total_amount) as total_order_amount
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_id, c.name
ORDER BY total_order_amount DESC
LIMIT 5;
customer_id: 5 name: Michael Wilson total_order_amount: 1290.00
customer_id: 4 name: Emily Davis total_order_amount: 1280.00
customer_id: 2 name: Jane Smith total_order_amount: 1095.00
customer_id: 3 name: Robert Johnson total_order_amount: 1045.00
customer_id: 1 name: John Doe total_order_amount: 850.00`
```
As you can see the answer is not correct as it sums the money spend rather than
the amount of orders. I guess it is because the attribute is named
`total_amount`. But it worked with a different temperature of 1.5:
```
SELECT c.name, COUNT(o.order_id) AS total_orders
FROM customers AS c
JOIN orders AS o ON c.customer_id = o.customer_id
GROUP BY c.name
ORDER BY total_orders DESC
LIMIT 5;
```
4. Write a query to find the average order amount for each customer.
```
SELECT customers.customer_id, AVG(orders.total_amount) AS average_order_amount
FROM customers
LEFT JOIN orders ON customers.customer_id = orders.customer_id
GROUP BY customers.customer_id;
Execute the query? [Y/n]:
customer_id: 5 average_order_amount: 322.5000000000000000
customer_id: 4 average_order_amount: 320.0000000000000000
customer_id: 10 average_order_amount: 250.0000000000000000
customer_id: 6 average_order_amount: 181.6666666666666667
customer_id: 2 average_order_amount: 273.7500000000000000
customer_id: 7 average_order_amount: 188.3333333333333333
customer_id: 1 average_order_amount: 212.5000000000000000
customer_id: 8 average_order_amount: 256.6666666666666667
customer_id: 9 average_order_amount: 221.6666666666666667
customer_id: 3 average_order_amount: 261.2500000000000000
```
5. Write a query to find the customers who have not placed any orders.
I first added costumers for that.
```
SELECT customer_id, name, email, address, city, state, country
FROM customers
WHERE customer_id NOT IN (SELECT customer_id FROM orders)
customer_id: 11 name: Funny Name 1 email: funny1@example.com address: 123 Funny Address city: Funny City state: Funny State country: Funny Country
customer_id: 12 name: Funny Name 2 email: funny2@example.com address: 456 Funny Address city: Funny City state: Funny State country: Funny Country
customer_id: 13 name: Funny Name 3 email: funny3@example.com address: 789 Funny Address city: Funny City state: Funny State country: Funny Country
```
6. Write a query to retrieve the number of orders placed in each country.
```
SELECT country, COUNT(*) as order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY country;
country: USA order_count: 35
```

View File

@ -12,13 +12,16 @@ import io.cequence.openaiscala.domain.ChatRole
import java.sql.{Connection, DriverManager, ResultSet, DatabaseMetaData}
import scala.io.StdIn.readLine
import org.postgresql.util.PSQLException
import java.io.File
import com.typesafe.config.ConfigFactory
@main def main(args: String*): Unit = {
val config = ConfigFactory.parseFile(new File("chatSql.conf"))
given ec: ExecutionContext = ExecutionContext.global
given actorSystem: ActorSystem = ActorSystem()
val service = OpenAIServiceFactory()
val service = OpenAIServiceFactory(config.getString("apiKey"))
Class.forName("org.postgresql.Driver")
val con_str = "jdbc:postgresql://localhost:5432/chatSql"
val con_str = config.getString("database")
val conn = DriverManager.getConnection(con_str)
val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
while (true) {
@ -58,41 +61,57 @@ import org.postgresql.util.PSQLException
val input = schema + readLine("Enter a query: ")
val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL, no explanation, do not warp it in a mardown code block"
val completion = Await.result(service.createChatCompletion(
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
settings = CreateChatCompletionSettings(
model = ModelId.gpt_3_5_turbo
)), Duration.Inf)
val query = completion.choices.head.message.content
println(query)
if (readLine("Execute the query? [Y/n]: ").toLowerCase != "n") {
try {
val resultSet = stm.executeQuery(query)
val metaData = resultSet.getMetaData
val columnCount = metaData.getColumnCount
val columnNames = (1 to columnCount).map(metaData.getColumnName)
var tryAgain = true
var temperature = 1.0
while (tryAgain) {
val completion = Await.result(service.createChatCompletion(
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
settings = CreateChatCompletionSettings(
model = ModelId.gpt_3_5_turbo,
temperature = Some(temperature)
)), Duration.Inf)
val query = completion.choices.head.message.content
println(query)
if (yesNoQuestion("Execute the query?")) {
tryAgain = false
try {
val resultSet = stm.executeQuery(query)
val metaData = resultSet.getMetaData
val columnCount = metaData.getColumnCount
val columnNames = (1 to columnCount).map(metaData.getColumnName)
// Process the query results
while (resultSet.next()) {
// Retrieve data for each column
columnNames.foreach { columnName =>
val columnValue = resultSet.getObject(columnName)
print(s"$columnName: $columnValue\t")
// Process the query results
while (resultSet.next()) {
// Retrieve data for each column
columnNames.foreach { columnName =>
val columnValue = resultSet.getObject(columnName)
print(s"$columnName: $columnValue\t")
}
println()
}
println()
resultSet.close()
}
resultSet.close()
}
catch {
case e: PSQLException => {
// ignore error: query has no output
if (e.getSQLState != "02000") {
println(e)
catch {
case e: PSQLException => {
// ignore error: query has no output
if (e.getSQLState != "02000") {
println(e)
}
}
}
}
else {
tryAgain = yesNoQuestion("Do you want to try it agian with a different temperature?")
if (tryAgain) {
temperature = readLine("Enter a temperature (double value) [0.0-2.0]: ").toDouble
}
}
}
}
stm.close()
conn.close()
}
def yesNoQuestion(question: String): Boolean = {
readLine(s"$question [Y/n]: ").toLowerCase != "n"
}