README
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,4 +3,4 @@
|
||||
!/LICENSE
|
||||
!/build.sbt
|
||||
!/README.md
|
||||
openai-scala-client.conf
|
||||
openai-scala-client.conf
|
||||
|
||||
236
README.md
236
README.md
@ -1,2 +1,238 @@
|
||||
<!-- LTeX: language=en-US -->
|
||||
|
||||
# chat-sql
|
||||
|
||||
chat-sql is a CLI that converts queries written in natural language to SQL
|
||||
queries via ChatGPT. It connects to a PostgreSQL database to execute the
|
||||
queries directly. The schema of all tables of the database is sent to ChatGPT
|
||||
in order to provide better results.
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you have a Scala 3 compiler and sbt installed. Then you can clone the
|
||||
git repository.
|
||||
|
||||
To use the ChatGPT API, you have to create an API key on the
|
||||
[OpenAI Platform](https://platform.openai.com/account/api-keys). You also need
|
||||
to buy tokens or use the free tokens you get during the first three months after
|
||||
you created your OpenAI account.
|
||||
|
||||
To configure chat-sql, create a config file `chatSql.conf` in the git directory
|
||||
with the following content:
|
||||
|
||||
```
|
||||
database = "jdbc:postgresql://localhost:5432/database_name?user=username&password=password"
|
||||
apiKey = "your api key"
|
||||
```
|
||||
|
||||
Replace the values accordingly.
|
||||
|
||||
## Usage
|
||||
|
||||
Start the program with `sbt run` and enter your query into the prompt.
|
||||
|
||||
If you do not get the expected result you can try to adjust ChatGPT's
|
||||
temperature. The temperature parameter in GPT-based models affects the
|
||||
randomness of generated text, including SQL queries. A higher temperature leads
|
||||
to more diverse and creative queries, but they may be less accurate or syntactically correct.
|
||||
A lower temperature produces more focused and deterministic queries,
|
||||
adhering closely to standard syntax. Choosing the temperature depends on the desired
|
||||
trade-off between query accuracy and creativity. Lower values prioritize
|
||||
correctness, while higher values introduce more randomness and exploration.
|
||||
Experimentation with different temperature values helps find the right balance
|
||||
for generating SQL queries from natural language.
|
||||
|
||||
## Example output
|
||||
|
||||
First, I asked the program to create the following tables and fill them with
|
||||
example data.
|
||||
|
||||
Table: Customers
|
||||
|
||||
| Column Name | Data Type |
|
||||
|--------------|----------|
|
||||
| customer_id | integer |
|
||||
| name | varchar |
|
||||
| email | varchar |
|
||||
| address | varchar |
|
||||
| city | varchar |
|
||||
| state | varchar |
|
||||
| country | varchar |
|
||||
|
||||
Table: Orders
|
||||
|
||||
| Column Name | Data Type |
|
||||
|-------------|----------|
|
||||
| order_id | integer |
|
||||
| customer_id | integer |
|
||||
| order_date | date |
|
||||
| total_amount| decimal |
|
||||
| status | varchar |
|
||||
|
||||
```
|
||||
$ create the follwing tables and fill it with 10 example customers and 20 example orders: Table: Customers Columns: customer_id (integer) primary key, name (varchar), email (varchar), address (varchar), city (varchar), state (varchar), country (varchar) Table: Orders Columns: order_id (integer) primary key, customer_id (integer), order_date (date), total_amount (decimal), status (varchar)
|
||||
CREATE TABLE Customers (
|
||||
customer_id INTEGER PRIMARY KEY,
|
||||
name VARCHAR,
|
||||
email VARCHAR,
|
||||
address VARCHAR,
|
||||
city VARCHAR,
|
||||
state VARCHAR,
|
||||
country VARCHAR
|
||||
);
|
||||
|
||||
CREATE TABLE Orders (
|
||||
order_id INTEGER PRIMARY KEY,
|
||||
customer_id INTEGER,
|
||||
order_date DATE,
|
||||
total_amount DECIMAL,
|
||||
status VARCHAR
|
||||
);
|
||||
|
||||
INSERT INTO Customers (customer_id, name, email, address, city, state, country)
|
||||
VALUES (1, 'John Doe', 'johndoe@example.com', '123 Main St', 'New York', 'NY', 'USA'),
|
||||
(2, 'Jane Smith', 'janesmith@example.com', '456 Elm St', 'Los Angeles', 'CA', 'USA'),
|
||||
(3, 'Robert Johnson', 'robertjohnson@example.com', '789 Oak St', 'Chicago', 'IL', 'USA'),
|
||||
(4, 'Emily Davis', 'emilydavis@example.com', '321 Pine St', 'Houston', 'TX', 'USA'),
|
||||
(5, 'Michael Wilson', 'michaelwilson@example.com', '654 Cedar St', 'Miami', 'FL', 'USA'),
|
||||
(6, 'Jessica Brown', 'jessicabrown@example.com', '987 Birch St', 'San Francisco', 'CA', 'USA'),
|
||||
(7, 'David Anderson', 'davidanderson@example.com', '159 Maple St', 'Seattle', 'WA', 'USA'),
|
||||
(8, 'Jennifer Martinez', 'jennifermartinez@example.com', '753 Spruce St', 'Boston', 'MA', 'USA'),
|
||||
(9, 'Daniel Thompson', 'danielthompson@example.com', '852 Walnut St', 'Phoenix', 'AZ', 'USA'),
|
||||
(10, 'Sarah Garcia', 'sarahgarcia@example.com', '369 Ash St', 'Denver', 'CO', 'USA');
|
||||
|
||||
INSERT INTO Orders (order_id, customer_id, order_date, total_amount, status)
|
||||
VALUES (1, 1, '2021-01-01', 100.00, 'Completed'),
|
||||
(2, 1, '2021-02-15', 50.00, 'Pending'),
|
||||
(3, 2, '2021-03-10', 200.00, 'Completed'),
|
||||
(4, 3, '2021-04-05', 75.00, 'Completed'),
|
||||
(5, 4, '2021-05-20', 120.00, 'Pending'),
|
||||
(6, 5, '2021-06-12', 80.00, 'Completed'),
|
||||
(7, 6, '2021-07-18', 150.00, 'Pending'),
|
||||
(8, 7, '2021-08-09', 90.00, 'Completed'),
|
||||
(9, 8, '2021-09-25', 180.00, 'Pending'),
|
||||
(10, 9, '2021-10-30', 60.00, 'Completed'),
|
||||
(11, 10, '2021-11-22', 135.00, 'Pending'),
|
||||
(12, 2, '2021-12-15', 95.00, 'Completed'),
|
||||
(13, 3, '2022-01-02', 70.00, 'Completed'),
|
||||
(14, 4, '2022-02-25', 160.00, 'Pending'),
|
||||
(15, 5, '2022-03-18', 110.00, 'Completed'),
|
||||
(16, 6, '2022-04-14', 45.00, 'Completed'),
|
||||
(17, 7, '2022-05-06', 75.00, 'Pending'),
|
||||
(18, 8, '2022-06-29', 140.00, 'Completed'),
|
||||
(19, 9, '2022-07-23', 105.00, 'Pending'),
|
||||
(20, 10, '2022-08-12', 65.00, 'Completed');
|
||||
```
|
||||
This took a second attempt; the first time, it did not generate the example data
|
||||
correctly because it did not write all examples out and used "..." instead.
|
||||
|
||||
I also wanted to add some more orders, so each customer does not have the same
|
||||
number of orders.
|
||||
```
|
||||
$ insert 15 different example orders and write out for all 15 orders. only use customer_id 1 to 10 and order_id above 20
|
||||
INSERT INTO orders (order_id, customer_id, order_date, total_amount, status) VALUES
|
||||
(21, 1, '2021-01-01', 100.00, 'Delivered'),
|
||||
(22, 2, '2021-01-02', 150.00, 'Delivered'),
|
||||
(23, 3, '2021-01-03', 200.00, 'Delivered'),
|
||||
(24, 4, '2021-01-04', 250.00, 'Delivered'),
|
||||
(25, 5, '2021-01-05', 300.00, 'Delivered'),
|
||||
(26, 6, '2021-01-06', 350.00, 'Delivered'),
|
||||
(27, 7, '2021-01-07', 400.00, 'Delivered'),
|
||||
(28, 8, '2021-01-08', 450.00, 'Delivered'),
|
||||
(29, 9, '2021-01-09', 500.00, 'Delivered'),
|
||||
(30, 10, '2021-01-10', 550.00, 'Delivered'),
|
||||
(31, 1, '2021-01-11', 600.00, 'Delivered'),
|
||||
(32, 2, '2021-01-12', 650.00, 'Delivered'),
|
||||
(33, 3, '2021-01-13', 700.00, 'Delivered'),
|
||||
(34, 4, '2021-01-14', 750.00, 'Delivered'),
|
||||
(35, 5, '2021-01-15', 800.00, 'Delivered');
|
||||
```
|
||||
I had to be a bit more specific here with the IDs because only the
|
||||
schema is provided to ChatGPT, and it therefore cannot know which IDs it
|
||||
can use. It would be interesting to input the complete table into GPT,
|
||||
but for large tables it would be too much data.
|
||||
It also used 'Delivered' as status instead of 'Completed'. However, it
|
||||
does not matter much for the follow-up questions.
|
||||
|
||||
Then I let ChatGPT itself come up with the following tasks:
|
||||
|
||||
1. Write a query to retrieve the total number of customers in the database.
|
||||
```
|
||||
SELECT COUNT(*) AS total_customers FROM customers;
|
||||
total_customers: 10
|
||||
```
|
||||
|
||||
2. Write a query to calculate the total revenue generated from all orders.
|
||||
```
|
||||
SELECT SUM(total_amount) AS total_revenue FROM orders;
|
||||
total_revenue: 8855.00
|
||||
```
|
||||
|
||||
3. Write a query to find the top 5 customers who have placed the highest total amount of orders.
|
||||
|
||||
|
||||
```
|
||||
SELECT c.customer_id, c.name, SUM(o.total_amount) as total_order_amount
|
||||
FROM customers c
|
||||
JOIN orders o ON c.customer_id = o.customer_id
|
||||
GROUP BY c.customer_id, c.name
|
||||
ORDER BY total_order_amount DESC
|
||||
LIMIT 5;
|
||||
customer_id: 5 name: Michael Wilson total_order_amount: 1290.00
|
||||
customer_id: 4 name: Emily Davis total_order_amount: 1280.00
|
||||
customer_id: 2 name: Jane Smith total_order_amount: 1095.00
|
||||
customer_id: 3 name: Robert Johnson total_order_amount: 1045.00
|
||||
customer_id: 1 name: John Doe total_order_amount: 850.00
|
||||
```
|
||||
|
||||
As you can see, the answer is not correct, as it sums the money spent rather than
|
||||
the number of orders. I guess it is because the attribute is named
|
||||
`total_amount`. But it worked with a different temperature of 1.5:
|
||||
```
|
||||
SELECT c.name, COUNT(o.order_id) AS total_orders
|
||||
FROM customers AS c
|
||||
JOIN orders AS o ON c.customer_id = o.customer_id
|
||||
GROUP BY c.name
|
||||
ORDER BY total_orders DESC
|
||||
LIMIT 5;
|
||||
```
|
||||
|
||||
4. Write a query to find the average order amount for each customer.
|
||||
```
|
||||
SELECT customers.customer_id, AVG(orders.total_amount) AS average_order_amount
|
||||
FROM customers
|
||||
LEFT JOIN orders ON customers.customer_id = orders.customer_id
|
||||
GROUP BY customers.customer_id;
|
||||
Execute the query? [Y/n]:
|
||||
customer_id: 5 average_order_amount: 322.5000000000000000
|
||||
customer_id: 4 average_order_amount: 320.0000000000000000
|
||||
customer_id: 10 average_order_amount: 250.0000000000000000
|
||||
customer_id: 6 average_order_amount: 181.6666666666666667
|
||||
customer_id: 2 average_order_amount: 273.7500000000000000
|
||||
customer_id: 7 average_order_amount: 188.3333333333333333
|
||||
customer_id: 1 average_order_amount: 212.5000000000000000
|
||||
customer_id: 8 average_order_amount: 256.6666666666666667
|
||||
customer_id: 9 average_order_amount: 221.6666666666666667
|
||||
customer_id: 3 average_order_amount: 261.2500000000000000
|
||||
```
|
||||
|
||||
5. Write a query to find the customers who have not placed any orders.
|
||||
|
||||
I first added customers for that.
|
||||
```
|
||||
SELECT customer_id, name, email, address, city, state, country
|
||||
FROM customers
|
||||
WHERE customer_id NOT IN (SELECT customer_id FROM orders)
|
||||
customer_id: 11 name: Funny Name 1 email: funny1@example.com address: 123 Funny Address city: Funny City state: Funny State country: Funny Country
|
||||
customer_id: 12 name: Funny Name 2 email: funny2@example.com address: 456 Funny Address city: Funny City state: Funny State country: Funny Country
|
||||
customer_id: 13 name: Funny Name 3 email: funny3@example.com address: 789 Funny Address city: Funny City state: Funny State country: Funny Country
|
||||
```
|
||||
|
||||
6. Write a query to retrieve the number of orders placed in each country.
|
||||
```
|
||||
SELECT country, COUNT(*) as order_count
|
||||
FROM customers c
|
||||
JOIN orders o ON c.customer_id = o.customer_id
|
||||
GROUP BY country;
|
||||
country: USA order_count: 35
|
||||
```
|
||||
|
||||
@ -12,13 +12,16 @@ import io.cequence.openaiscala.domain.ChatRole
|
||||
import java.sql.{Connection, DriverManager, ResultSet, DatabaseMetaData}
|
||||
import scala.io.StdIn.readLine
|
||||
import org.postgresql.util.PSQLException
|
||||
import java.io.File
|
||||
import com.typesafe.config.ConfigFactory
|
||||
|
||||
@main def main(args: String*): Unit = {
|
||||
val config = ConfigFactory.parseFile(new File("chatSql.conf"))
|
||||
given ec: ExecutionContext = ExecutionContext.global
|
||||
given actorSystem: ActorSystem = ActorSystem()
|
||||
val service = OpenAIServiceFactory()
|
||||
val service = OpenAIServiceFactory(config.getString("apiKey"))
|
||||
Class.forName("org.postgresql.Driver")
|
||||
val con_str = "jdbc:postgresql://localhost:5432/chatSql"
|
||||
val con_str = config.getString("database")
|
||||
val conn = DriverManager.getConnection(con_str)
|
||||
val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
|
||||
while (true) {
|
||||
@ -58,41 +61,57 @@ import org.postgresql.util.PSQLException
|
||||
|
||||
val input = schema + readLine("Enter a query: ")
|
||||
val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL, no explanation, do not warp it in a mardown code block"
|
||||
val completion = Await.result(service.createChatCompletion(
|
||||
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
|
||||
settings = CreateChatCompletionSettings(
|
||||
model = ModelId.gpt_3_5_turbo
|
||||
)), Duration.Inf)
|
||||
val query = completion.choices.head.message.content
|
||||
println(query)
|
||||
if (readLine("Execute the query? [Y/n]: ").toLowerCase != "n") {
|
||||
try {
|
||||
val resultSet = stm.executeQuery(query)
|
||||
val metaData = resultSet.getMetaData
|
||||
val columnCount = metaData.getColumnCount
|
||||
val columnNames = (1 to columnCount).map(metaData.getColumnName)
|
||||
var tryAgain = true
|
||||
var temperature = 1.0
|
||||
while (tryAgain) {
|
||||
val completion = Await.result(service.createChatCompletion(
|
||||
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
|
||||
settings = CreateChatCompletionSettings(
|
||||
model = ModelId.gpt_3_5_turbo,
|
||||
temperature = Some(temperature)
|
||||
)), Duration.Inf)
|
||||
val query = completion.choices.head.message.content
|
||||
println(query)
|
||||
if (yesNoQuestion("Execute the query?")) {
|
||||
tryAgain = false
|
||||
try {
|
||||
val resultSet = stm.executeQuery(query)
|
||||
val metaData = resultSet.getMetaData
|
||||
val columnCount = metaData.getColumnCount
|
||||
val columnNames = (1 to columnCount).map(metaData.getColumnName)
|
||||
|
||||
// Process the query results
|
||||
while (resultSet.next()) {
|
||||
// Retrieve data for each column
|
||||
columnNames.foreach { columnName =>
|
||||
val columnValue = resultSet.getObject(columnName)
|
||||
print(s"$columnName: $columnValue\t")
|
||||
// Process the query results
|
||||
while (resultSet.next()) {
|
||||
// Retrieve data for each column
|
||||
columnNames.foreach { columnName =>
|
||||
val columnValue = resultSet.getObject(columnName)
|
||||
print(s"$columnName: $columnValue\t")
|
||||
}
|
||||
println()
|
||||
}
|
||||
println()
|
||||
resultSet.close()
|
||||
}
|
||||
resultSet.close()
|
||||
}
|
||||
catch {
|
||||
case e: PSQLException => {
|
||||
// ignore error: query has no output
|
||||
if (e.getSQLState != "02000") {
|
||||
println(e)
|
||||
catch {
|
||||
case e: PSQLException => {
|
||||
// ignore error: query has no output
|
||||
if (e.getSQLState != "02000") {
|
||||
println(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
tryAgain = yesNoQuestion("Do you want to try it agian with a different temperature?")
|
||||
if (tryAgain) {
|
||||
temperature = readLine("Enter a temperature (double value) [0.0-2.0]: ").toDouble
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stm.close()
|
||||
conn.close()
|
||||
}
|
||||
|
||||
/** Asks the user a yes/no question on the console, with "yes" as the default.
  *
  * The prompt is printed as `<question> [Y/n]: ` and one line is read from
  * standard input.
  *
  * @param question text shown in front of the `[Y/n]` hint
  * @return `false` when the answer is `n` or `N` (surrounding whitespace is
  *         ignored) or when standard input has ended; `true` for any other
  *         answer, including an empty line
  */
def yesNoQuestion(question: String): Boolean = {
  // readLine treats its first argument as a printf format string, so the
  // question is passed as an argument to keep a literal '%' in it harmless.
  // readLine returns null at end of input; Option turns that into None so
  // that no NullPointerException is thrown. Without an explicit answer the
  // safe choice is "no".
  Option(readLine("%s [Y/n]: ", question)) match {
    case Some(answer) => answer.trim.toLowerCase != "n"
    case None         => false
  }
}
|
||||
|
||||
Reference in New Issue
Block a user