basic version
This commit is contained in:
@ -9,7 +9,7 @@ import io.cequence.openaiscala.domain.ModelId
|
|||||||
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
|
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
|
||||||
import io.cequence.openaiscala.domain.MessageSpec
|
import io.cequence.openaiscala.domain.MessageSpec
|
||||||
import io.cequence.openaiscala.domain.ChatRole
|
import io.cequence.openaiscala.domain.ChatRole
|
||||||
import java.sql.{Connection, DriverManager, ResultSet}
|
import java.sql.{Connection, DriverManager, ResultSet, DatabaseMetaData}
|
||||||
import scala.io.StdIn.readLine
|
import scala.io.StdIn.readLine
|
||||||
import org.postgresql.util.PSQLException
|
import org.postgresql.util.PSQLException
|
||||||
|
|
||||||
@ -21,9 +21,43 @@ import org.postgresql.util.PSQLException
|
|||||||
val con_str = "jdbc:postgresql://localhost:5432/chatSql"
|
val con_str = "jdbc:postgresql://localhost:5432/chatSql"
|
||||||
val conn = DriverManager.getConnection(con_str)
|
val conn = DriverManager.getConnection(con_str)
|
||||||
val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
|
val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
|
||||||
|
var schema = ""
|
||||||
while (true) {
|
while (true) {
|
||||||
val input = readLine("Enter a query: ")
|
// Get the DatabaseMetaData
|
||||||
val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL"
|
val metaData: DatabaseMetaData = conn.getMetaData
|
||||||
|
|
||||||
|
// Get the tables' schema information
|
||||||
|
val tablesResultSet: ResultSet = metaData.getTables(null, null, "%", null)
|
||||||
|
|
||||||
|
// Iterate over the tables and print their schema
|
||||||
|
while (tablesResultSet.next()) {
|
||||||
|
val tableName = tablesResultSet.getString("TABLE_NAME")
|
||||||
|
val tableType = tablesResultSet.getString("TABLE_TYPE")
|
||||||
|
|
||||||
|
// Check if the table is a user-created table
|
||||||
|
if (tableType != null && tableType.equalsIgnoreCase("TABLE")) {
|
||||||
|
schema += s"Schema for table: $tableName\n"
|
||||||
|
|
||||||
|
// Get the columns' schema information for the current table
|
||||||
|
val columnsResultSet: ResultSet = metaData.getColumns(null, null, tableName, null)
|
||||||
|
|
||||||
|
// Iterate over the columns and print their schema
|
||||||
|
while (columnsResultSet.next()) {
|
||||||
|
val columnName = columnsResultSet.getString("COLUMN_NAME")
|
||||||
|
val dataType = columnsResultSet.getString("TYPE_NAME")
|
||||||
|
val columnSize = columnsResultSet.getInt("COLUMN_SIZE")
|
||||||
|
val nullable = columnsResultSet.getBoolean("NULLABLE")
|
||||||
|
|
||||||
|
schema += s"Column: $columnName, Type: $dataType, Size: $columnSize, Nullable: $nullable\n"
|
||||||
|
}
|
||||||
|
columnsResultSet.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Close the resources
|
||||||
|
tablesResultSet.close()
|
||||||
|
|
||||||
|
val input = schema + readLine("Enter a query: ")
|
||||||
|
val systemInfo = "Convert the following Sentence to an SQL query. Return only SQL, no explanation, do not warp it in a mardown code block"
|
||||||
val completion = Await.result(service.createChatCompletion(
|
val completion = Await.result(service.createChatCompletion(
|
||||||
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
|
Seq(MessageSpec(ChatRole.System, systemInfo), MessageSpec(ChatRole.User, input)),
|
||||||
settings = CreateChatCompletionSettings(
|
settings = CreateChatCompletionSettings(
|
||||||
@ -31,15 +65,34 @@ import org.postgresql.util.PSQLException
|
|||||||
)), Duration.Inf)
|
)), Duration.Inf)
|
||||||
val query = completion.choices.head.message.content
|
val query = completion.choices.head.message.content
|
||||||
println(query)
|
println(query)
|
||||||
try {
|
if (readLine("Execute the query? [Y/n]: ").toLowerCase != "n") {
|
||||||
val rs = stm.executeQuery(query)
|
try {
|
||||||
while(rs.next) {
|
val resultSet = stm.executeQuery(query)
|
||||||
println(rs.getFloat("price"))
|
val metaData = resultSet.getMetaData
|
||||||
|
val columnCount = metaData.getColumnCount
|
||||||
|
val columnNames = (1 to columnCount).map(metaData.getColumnName)
|
||||||
|
|
||||||
|
// Process the query results
|
||||||
|
while (resultSet.next()) {
|
||||||
|
// Retrieve data for each column
|
||||||
|
columnNames.foreach { columnName =>
|
||||||
|
val columnValue = resultSet.getObject(columnName)
|
||||||
|
print(s"$columnName: $columnValue\t")
|
||||||
|
}
|
||||||
|
println()
|
||||||
|
}
|
||||||
|
resultSet.close()
|
||||||
|
}
|
||||||
|
catch {
|
||||||
|
case e: PSQLException => {
|
||||||
|
// ignore error: query has no output
|
||||||
|
if (e.getSQLState != "02000") {
|
||||||
|
println(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch {
|
|
||||||
case e: PSQLException => println(e)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
stm.close()
|
||||||
conn.close()
|
conn.close()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user