-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft dspy implementation #9097
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice start! Added some comments for how we can get this ready to merge 🚀
@@ -0,0 +1,9 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should remove this altogether. What is it used for?
# Input of llm is a dataframe- this creates a place for both the input | ||
# and the output to be saved | ||
cursor.execute(''' | ||
CREATE TABLE IF NOT EXISTS llm_io_data ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to create this table in our db module. You can do this by creating a new class:
class LLMData(Base):
__tablename__ = "llm_data"
id = Column(Integer, primary_key=True)
...
and then writing a DB migration so we can easily upgrade/downgrade the DB. Check out the Alembic docs (the Python library we are using) and some examples for how to do this
output TEXT | ||
)''') | ||
cursor.execute(''' | ||
CREATE TABLE IF NOT EXISTS retrieval_data ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above, we need to remove this SQL and create a new table in our db module with a migration.
def update_retrieval_index(self, input_data, output_data): | ||
cursor = self.con.cursor() | ||
cursor.execute(''' | ||
INSERT INTO retrieval_data (question, answer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally, you should never execute raw SQL like this outside of a dedicated DAO (data access object). This should look something like:
def update_retrieval_index(self, input_data, output_data):
self.retrieval_data_controller.add(input_data, output_data)
# Could also add this functionality to our existing agents_controller
# self.agents_controller.add_retrieval_data(input_data, output_data)
@@ -322,9 +367,115 @@ def _invoke_agent_executor_with_prompt(agent_executor, prompt): | |||
|
|||
return pred_df | |||
|
|||
def initialize_database(self): | |||
# Connect to an sqlite database | |||
self.con = sqlite3.connect('llm_data.db') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After we create tables in our DB module & access them with DAOs (i.e. our agents_controller
, model_controller
, etc), we should remove this
# Extract the relevant response part if necessary | ||
return response['output'], context | ||
except: | ||
return "Error in processing your request.", context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this be returned directly to the user? If so, could consider improving the error message. If not, it's fine
def on_self_improvement_start(self, context: Dict[str, Any]) -> Any: | ||
'''Run when the agent's self-improvement process starts.''' | ||
self.logger.debug('Self-improvement process started with context:') | ||
self.logger.debug(str(context)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's possible that some context
are unable to be printed since the value (type Any
) doesn't support printing. I ran into this issue myself when logging. Could just simply wrap it in a try
for now:
try:
self.logger.debug(str(context))
def on_self_improvement_end(self, result: Dict[str, Any]) -> Any: | ||
'''Run when the agent's self-improvement process ends.''' | ||
self.logger.debug('Self-improvement process ended with result:') | ||
self.logger.debug(str(result)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above for wrapping in a try
(just the result
logging)
@@ -72,10 +91,14 @@ def __init__( | |||
# if True, the target column name does not have to be specified at creation time. | |||
self.generative = True | |||
self.default_agent_tools = DEFAULT_AGENT_TOOLS | |||
self.initialize_database() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We won't need this after we update the PR to use controllers & our defined DB tables (see below comments)
self.log_callback_handler = log_callback_handler | ||
self.langfuse_callback_handler = langfuse_callback_handler | ||
if self.log_callback_handler is None: | ||
self.log_callback_handler = LogCallbackHandler(logger) | ||
self.use_dspy = kwargs.get('use_dspy', False) # option to use DSPy or not |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be checked in our create
and predict
methods instead, from the args
passed in those
Description
Draft of adding DSPy self improvement feature to the langchain handler file. This is currently an optional functionality to the file and uses the sqlite database to save and access past llm prompts and responses.