From f9d339adbc15ff6212a49c3d8d726372f44d20dd Mon Sep 17 00:00:00 2001 From: MrGeorgen Date: Sun, 25 Jun 2023 19:13:54 +0200 Subject: [PATCH] basic version --- src/main/scala/chat_sql/main.scala | 73 ++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 10 deletions(-) diff --git a/src/main/scala/chat_sql/main.scala b/src/main/scala/chat_sql/main.scala index 087465a..dd762af 100644 --- a/src/main/scala/chat_sql/main.scala +++ b/src/main/scala/chat_sql/main.scala @@ -9,7 +9,7 @@ import io.cequence.openaiscala.domain.ModelId import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings import io.cequence.openaiscala.domain.MessageSpec import io.cequence.openaiscala.domain.ChatRole -import java.sql.{Connection, DriverManager, ResultSet} +import java.sql.{Connection, DriverManager, ResultSet, DatabaseMetaData} import scala.io.StdIn.readLine import org.postgresql.util.PSQLException @@ -21,9 +21,43 @@ import org.postgresql.util.PSQLException val con_str = "jdbc:postgresql://localhost:5432/chatSql" val conn = DriverManager.getConnection(con_str) val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + var schema = "" while (true) { - val input = readLine("Enter a query: ") - val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL" + // Get the DatabaseMetaData + val metaData: DatabaseMetaData = conn.getMetaData + + // Get the tables' schema information + val tablesResultSet: ResultSet = metaData.getTables(null, null, "%", null) + + // Iterate over the tables and print their schema + while (tablesResultSet.next()) { + val tableName = tablesResultSet.getString("TABLE_NAME") + val tableType = tablesResultSet.getString("TABLE_TYPE") + + // Check if the table is a user-created table + if (tableType != null && tableType.equalsIgnoreCase("TABLE")) { + schema += s"Schema for table: $tableName\n" + + // Get the columns' schema information for the current table + val columnsResultSet: ResultSet = metaData.getColumns(null, null, tableName, null) + + // Iterate over the columns and print their schema + while (columnsResultSet.next()) { + val columnName = columnsResultSet.getString("COLUMN_NAME") + val dataType = columnsResultSet.getString("TYPE_NAME") + val columnSize = columnsResultSet.getInt("COLUMN_SIZE") + val nullable = columnsResultSet.getBoolean("NULLABLE") + + schema += s"Column: $columnName, Type: $dataType, Size: $columnSize, Nullable: $nullable\n" + } + columnsResultSet.close() + } + } + // Close the resources + tablesResultSet.close() + + val input = schema + readLine("Enter a query: ") + val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL, no explanation, do not warp it in a mardown code block" val completion = Await.result(service.createChatCompletion( Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)), settings = CreateChatCompletionSettings( @@ -31,15 +65,34 @@ import org.postgresql.util.PSQLException )), Duration.Inf) val query = completion.choices.head.message.content println(query) - try { - val rs = stm.executeQuery(query) - while(rs.next) { - println(rs.getFloat("price")) + if (readLine("Execute the query? [Y/n]: ").toLowerCase != "n") { + try { + val resultSet = stm.executeQuery(query) + val metaData = resultSet.getMetaData + val columnCount = metaData.getColumnCount + val columnNames = (1 to columnCount).map(metaData.getColumnName) + + // Process the query results + while (resultSet.next()) { + // Retrieve data for each column + columnNames.foreach { columnName => + val columnValue = resultSet.getObject(columnName) + print(s"$columnName: $columnValue\t") + } + println() + } + resultSet.close() + } + catch { + case e: PSQLException => { + // ignore error: query has no output + if (e.getSQLState != "02000") { + println(e) + } + } } } - catch { - case e: PSQLException => println(e) - } } + stm.close() conn.close() }