From 16ef0d0ebd6d8f0bd6581cffefc47a5d2145a6d9 Mon Sep 17 00:00:00 2001 From: MrGeorgen Date: Sun, 16 Jul 2023 14:37:12 +0200 Subject: [PATCH] README --- .gitignore | 2 +- README.md | 236 +++++++++++++++++++++++++++++ src/main/scala/chat_sql/main.scala | 77 ++++++---- 3 files changed, 285 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index a203883..0fcc4cc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ !/LICENSE !/build.sbt !/README.md -openai-scala-client.conf \ No newline at end of file +openai-scala-client.conf diff --git a/README.md b/README.md index c1c7d83..c1e684b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,238 @@ + + # chat-sql +chat-sql is a cli, which converts queries written in natural language to SQL +queries via ChatGPT. It connects a PostgreSQL database, to execute the +queries directly. The schema of the all tables of the database is sent to ChatGPT +in order to provide better results. + +## Setup + +Make sure you have a Scala 3 compiler and sbt installed. Then you can clone the +git repository. + +To use the ChatGPT api you have to create an API key on the +[OpenAI Platform](https://platform.openai.com/account/api-keys). You also need +to buy tokens or use the tokens you get for free the first three months after +you created your OpenAI account. + +To configure chat-sql, create a config file `chatSql.conf` in the git directory +with the following content: + +``` +database = "jdbc:postgresql://localhost:5432/database_name?user=username&password=password" +apiKey = "your api key" +``` + +Replace the values accordingly. + +## Usage + +Start the program with `sbt run` and enter your query into the prompt. + +If you do not get the expected result you can try to adjust ChatGPT's +temperature. The temperature parameter in GPT-based models affects the +randomness of generated text, including SQL queries. A higher temperature leads +to more diverse and creative queries, but they may be less accurate or syntactically correct. +A lower temperature produces more focused and deterministic queries, +adhering closely to standard syntax. Choosing the temperature depends on the desired +trade-off between query accuracy and creativity. Lower values prioritize +correctness, while higher values introduce more randomness and exploration. +Experimentation with different temperature values helps find the right balance +for generating SQL queries from natural language. + +## Example output + +First I asked the program, to create the following tables and fill it with +example data. + +Table: Customers + +| Column Name | Data Type | +|--------------|----------| +| customer_id | integer | +| name | varchar | +| email | varchar | +| address | varchar | +| city | varchar | +| state | varchar | +| country | varchar | + +Table: Orders + +| Column Name | Data Type | +|-------------|----------| +| order_id | integer | +| customer_id | integer | +| order_date | date | +| total_amount| decimal | +| status | varchar | + +``` +$ create the follwing tables and fill it with 10 example customers and 20 example orders: Table: Customers Columns: customer_id (integer) primary key, name (varchar), email (varchar), address (varchar), city (varchar), state (varchar), country (varchar) Table: Orders Columns: order_id (integer) primary key, customer_id (integer), order_date (date), total_amount (decimal), status (varchar) +CREATE TABLE Customers CREATE TABLE Customers ( + customer_id INTEGER PRIMARY KEY, + name VARCHAR, + email VARCHAR, + address VARCHAR, + city VARCHAR, + state VARCHAR, + country VARCHAR +); + +CREATE TABLE Orders ( + order_id INTEGER PRIMARY KEY, + customer_id INTEGER, + order_date DATE, + total_amount DECIMAL, + status VARCHAR +); + +INSERT INTO Customers (customer_id, name, email, address, city, state, country) +VALUES (1, 'John Doe', 'johndoe@example.com', '123 Main St', 'New York', 'NY', 'USA'), + (2, 'Jane Smith', 'janesmith@example.com', '456 Elm St', 'Los Angeles', 'CA', 'USA'), + (3, 'Robert Johnson', 'robertjohnson@example.com', '789 Oak St', 'Chicago', 'IL', 'USA'), + (4, 'Emily Davis', 'emilydavis@example.com', '321 Pine St', 'Houston', 'TX', 'USA'), + (5, 'Michael Wilson', 'michaelwilson@example.com', '654 Cedar St', 'Miami', 'FL', 'USA'), + (6, 'Jessica Brown', 'jessicabrown@example.com', '987 Birch St', 'San Francisco', 'CA', 'USA'), + (7, 'David Anderson', 'davidanderson@example.com', '159 Maple St', 'Seattle', 'WA', 'USA'), + (8, 'Jennifer Martinez', 'jennifermartinez@example.com', '753 Spruce St', 'Boston', 'MA', 'USA'), + (9, 'Daniel Thompson', 'danielthompson@example.com', '852 Walnut St', 'Phoenix', 'AZ', 'USA'), + (10, 'Sarah Garcia', 'sarahgarcia@example.com', '369 Ash St', 'Denver', 'CO', 'USA'); + +INSERT INTO Orders (order_id, customer_id, order_date, total_amount, status) +VALUES (1, 1, '2021-01-01', 100.00, 'Completed'), + (2, 1, '2021-02-15', 50.00, 'Pending'), + (3, 2, '2021-03-10', 200.00, 'Completed'), + (4, 3, '2021-04-05', 75.00, 'Completed'), + (5, 4, '2021-05-20', 120.00, 'Pending'), + (6, 5, '2021-06-12', 80.00, 'Completed'), + (7, 6, '2021-07-18', 150.00, 'Pending'), + (8, 7, '2021-08-09', 90.00, 'Completed'), + (9, 8, '2021-09-25', 180.00, 'Pending'), + (10, 9, '2021-10-30', 60.00, 'Completed'), + (11, 10, '2021-11-22', 135.00, 'Pending'), + (12, 2, '2021-12-15', 95.00, 'Completed'), + (13, 3, '2022-01-02', 70.00, 'Completed'), + (14, 4, '2022-02-25', 160.00, 'Pending'), + (15, 5, '2022-03-18', 110.00, 'Completed'), + (16, 6, '2022-04-14', 45.00, 'Completed'), + (17, 7, '2022-05-06', 75.00, 'Pending'), + (18, 8, '2022-06-29', 140.00, 'Completed'), + (19, 9, '2022-07-23', 105.00, 'Pending'), + (20, 10, '2022-08-12', 65.00, 'Completed'); +``` +This took a second attempt, the first time, it did not generate the example data +correctly because it did not write all examples out and used "..." instead. + +I also wanted to add some more orders, so each customer does not have the same +amount of orders. +``` +$ insert 15 different example orders and write out for all 15 orders. only use customer_id 1 to 10 and order_id above 20 +INSERT INTO orders (order_id, customer_id, order_date, total_amount, status) VALUES +(21, 1, '2021-01-01', 100.00, 'Delivered'), +(22, 2, '2021-01-02', 150.00, 'Delivered'), +(23, 3, '2021-01-03', 200.00, 'Delivered'), +(24, 4, '2021-01-04', 250.00, 'Delivered'), +(25, 5, '2021-01-05', 300.00, 'Delivered'), +(26, 6, '2021-01-06', 350.00, 'Delivered'), +(27, 7, '2021-01-07', 400.00, 'Delivered'), +(28, 8, '2021-01-08', 450.00, 'Delivered'), +(29, 9, '2021-01-09', 500.00, 'Delivered'), +(30, 10, '2021-01-10', 550.00, 'Delivered'), +(31, 1, '2021-01-11', 600.00, 'Delivered'), +(32, 2, '2021-01-12', 650.00, 'Delivered'), +(33, 3, '2021-01-13', 700.00, 'Delivered'), +(34, 4, '2021-01-14', 750.00, 'Delivered'), +(35, 5, '2021-01-15', 800.00, 'Delivered'); +``` +I had to be a bit more specific here with the IDs because only the +schema is provided to ChatGPT and therefore can not know which IDs it +can use. It would be interesting to input the complete table into GPT, +but for large tables it will be to much data. +It also used 'Delivered' as status instead of 'Completed'. However, it +does not matter much for the follow-up questions. + +Then I let ChatGPT themself come up with the following tasks: + +1. Write a query to retrieve the total number of customers in the database. +``` +SELECT COUNT(*) AS total_customers FROM customers; +total_customers: 10 +``` + +2. Write a query to calculate the total revenue generated from all orders. +``` +SELECT SUM(total_amount) AS total_revenue FROM orders; +total_revenue: 8855.00 +``` + +3. Write a query to find the top 5 customers who have placed the highest total amount of orders. + + +``` +SELECT c.customer_id, c.name, SUM(o.total_amount) as total_order_amount +FROM customers c +JOIN orders o ON c.customer_id = o.customer_id +GROUP BY c.customer_id, c.name +ORDER BY total_order_amount DESC +LIMIT 5; +customer_id: 5 name: Michael Wilson total_order_amount: 1290.00 +customer_id: 4 name: Emily Davis total_order_amount: 1280.00 +customer_id: 2 name: Jane Smith total_order_amount: 1095.00 +customer_id: 3 name: Robert Johnson total_order_amount: 1045.00 +customer_id: 1 name: John Doe total_order_amount: 850.00` +``` + +As you can see the answer is not correct as it sums the money spend rather than +the amount of orders. I guess it is because the attribute is named +`total_amount`. But it worked with a different temperature of 1.5: +``` +SELECT c.name, COUNT(o.order_id) AS total_orders +FROM customers AS c +JOIN orders AS o ON c.customer_id = o.customer_id +GROUP BY c.name +ORDER BY total_orders DESC +LIMIT 5; +``` + +4. Write a query to find the average order amount for each customer. +``` +SELECT customers.customer_id, AVG(orders.total_amount) AS average_order_amount +FROM customers +LEFT JOIN orders ON customers.customer_id = orders.customer_id +GROUP BY customers.customer_id; +Execute the query? [Y/n]: +customer_id: 5 average_order_amount: 322.5000000000000000 +customer_id: 4 average_order_amount: 320.0000000000000000 +customer_id: 10 average_order_amount: 250.0000000000000000 +customer_id: 6 average_order_amount: 181.6666666666666667 +customer_id: 2 average_order_amount: 273.7500000000000000 +customer_id: 7 average_order_amount: 188.3333333333333333 +customer_id: 1 average_order_amount: 212.5000000000000000 +customer_id: 8 average_order_amount: 256.6666666666666667 +customer_id: 9 average_order_amount: 221.6666666666666667 +customer_id: 3 average_order_amount: 261.2500000000000000 +``` + +5. Write a query to find the customers who have not placed any orders. + +I first added costumers for that. +``` +SELECT customer_id, name, email, address, city, state, country +FROM customers +WHERE customer_id NOT IN (SELECT customer_id FROM orders) +customer_id: 11 name: Funny Name 1 email: funny1@example.com address: 123 Funny Address city: Funny City state: Funny State country: Funny Country +customer_id: 12 name: Funny Name 2 email: funny2@example.com address: 456 Funny Address city: Funny City state: Funny State country: Funny Country +customer_id: 13 name: Funny Name 3 email: funny3@example.com address: 789 Funny Address city: Funny City state: Funny State country: Funny Country +``` + +6. Write a query to retrieve the number of orders placed in each country. +``` +SELECT country, COUNT(*) as order_count +FROM customers c +JOIN orders o ON c.customer_id = o.customer_id +GROUP BY country; +country: USA order_count: 35 +``` diff --git a/src/main/scala/chat_sql/main.scala b/src/main/scala/chat_sql/main.scala index 153265c..735aacb 100644 --- a/src/main/scala/chat_sql/main.scala +++ b/src/main/scala/chat_sql/main.scala @@ -12,13 +12,16 @@ import io.cequence.openaiscala.domain.ChatRole import java.sql.{Connection, DriverManager, ResultSet, DatabaseMetaData} import scala.io.StdIn.readLine import org.postgresql.util.PSQLException +import java.io.File +import com.typesafe.config.ConfigFactory @main def main(args: String*): Unit = { + val config = ConfigFactory.parseFile(new File("chatSql.conf")) given ec: ExecutionContext = ExecutionContext.global given actorSystem: ActorSystem = ActorSystem() - val service = OpenAIServiceFactory() + val service = OpenAIServiceFactory(config.getString("apiKey")) Class.forName("org.postgresql.Driver") - val con_str = "jdbc:postgresql://localhost:5432/chatSql" + val con_str = config.getString("database") val conn = DriverManager.getConnection(con_str) val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) while (true) { @@ -58,41 +61,57 @@ import org.postgresql.util.PSQLException val input = schema + readLine("Enter a query: ") val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL, no explanation, do not warp it in a mardown code block" - val completion = Await.result(service.createChatCompletion( - Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)), - settings = CreateChatCompletionSettings( - model = ModelId.gpt_3_5_turbo - )), Duration.Inf) - val query = completion.choices.head.message.content - println(query) - if (readLine("Execute the query? [Y/n]: ").toLowerCase != "n") { - try { - val resultSet = stm.executeQuery(query) - val metaData = resultSet.getMetaData - val columnCount = metaData.getColumnCount - val columnNames = (1 to columnCount).map(metaData.getColumnName) + var tryAgain = true + var temperature = 1.0 + while (tryAgain) { + val completion = Await.result(service.createChatCompletion( + Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)), + settings = CreateChatCompletionSettings( + model = ModelId.gpt_3_5_turbo, + temperature = Some(temperature) + )), Duration.Inf) + val query = completion.choices.head.message.content + println(query) + if (yesNoQuestion("Execute the query?")) { + tryAgain = false + try { + val resultSet = stm.executeQuery(query) + val metaData = resultSet.getMetaData + val columnCount = metaData.getColumnCount + val columnNames = (1 to columnCount).map(metaData.getColumnName) - // Process the query results - while (resultSet.next()) { - // Retrieve data for each column - columnNames.foreach { columnName => - val columnValue = resultSet.getObject(columnName) - print(s"$columnName: $columnValue\t") + // Process the query results + while (resultSet.next()) { + // Retrieve data for each column + columnNames.foreach { columnName => + val columnValue = resultSet.getObject(columnName) + print(s"$columnName: $columnValue\t") + } + println() } - println() + resultSet.close() } - resultSet.close() - } - catch { - case e: PSQLException => { - // ignore error: query has no output - if (e.getSQLState != "02000") { - println(e) + catch { + case e: PSQLException => { + // ignore error: query has no output + if (e.getSQLState != "02000") { + println(e) + } } } } + else { + tryAgain = yesNoQuestion("Do you want to try it agian with a different temperature?") + if (tryAgain) { + temperature = readLine("Enter a temperature (double value) [0.0-2.0]: ").toDouble + } + } } } stm.close() conn.close() } + +def yesNoQuestion(question: String): Boolean = { + readLine(s"$question [Y/n]: ").toLowerCase != "n" +}