A bunch of improvements for the classification skill #50

Draft
wants to merge 7 commits into master
Conversation

@chemeris commented Jan 6, 2024

Thank you for an excellent library. I ran it against our dataset of 10k social media posts to detect their sentiment and classify it into a set of topics. This set of patches is what it took for me to get it working - from minor fixes to improvements in the learning process.

I have a bunch of ideas on how to improve learning performance with more advanced learning strategies - I'd be happy to discuss this if there is interest in implementing them.

PS Let me know if you want me to break this into smaller PRs.

1. Make the number of rows to print a configurable parameter (defaulting to 5).
2. Add a parameter to enable index column printing. This implementation
   prints the index of the original data frame if the passed data frame is
   derived (see the sketch below).
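
A minimal sketch of the intended behavior, assuming pandas; `print_dataframe` and its parameter names are illustrative, not the actual library API:

```python
import pandas as pd

def print_dataframe(df: pd.DataFrame, num_rows: int = 5, print_index: bool = False):
    # If `df` was derived from another frame (filtered, sampled, ...), its
    # index still points at rows of the original frame, so printing it lets
    # you trace each row back to the source data.
    print(df.head(num_rows).to_string(index=print_index))
```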
Otherwise we don't know whether the training improved the accuracy or not.

We should eventually introduce smarter learning strategies, e.g.
simple ones like rejecting changes that make accuracy worse,
or complex ones with genetic algorithms like in FunSearch.
The accuracy threshold was originally ignored and unused, which made
training quite difficult in practical scenarios.
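
A rough sketch of how both ideas could fit into a training loop (all names here are hypothetical, not the current API):

```python
def train(skill, dataset, max_iterations: int = 10, accuracy_threshold: float = 0.95):
    # Hypothetical greedy loop: keep a prompt rewrite only if it improves
    # accuracy, and stop early once the threshold is reached.
    best_prompt = skill.instructions
    best_accuracy = evaluate(skill, dataset)  # hypothetical evaluation helper
    for _ in range(max_iterations):
        if best_accuracy >= accuracy_threshold:
            break  # threshold reached: stop instead of ignoring it
        skill.improve(dataset)  # proposes and applies a new prompt
        accuracy = evaluate(skill, dataset)
        if accuracy > best_accuracy:
            best_prompt, best_accuracy = skill.instructions, accuracy
        else:
            skill.instructions = best_prompt  # reject changes that made accuracy worse
    return best_prompt, best_accuracy
```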
When providing feedback to the model, mention which output is wrong.
Otherwise the model doesn't have enough information about which of
the outputs is correct or incorrect.
1. Phrase the prompt in a more imperative manner, which GPT models
   respond to better.
2. Instruct the teacher model to avoid unnecessary rephrasing of
   the prompt. With GPT-4 this results in far fewer unnecessary
   changes. When a skill has multiple outputs, each rewrite of one
   skill output also changes the wording of all the other outputs,
   unnecessarily distorting and degrading their performance. This
   phrasing significantly reduces such distortion but doesn't remove
   it completely. Running training cycles on each skill output
   separately solves this completely but is much, much slower. Another
   potential solution (I haven't tried it yet) is to collect feedback
   for all outputs and apply it all in a single go (see the sketch
   after this list). More testing with real data is needed here.
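
For reference, a sketch of the "collect feedback for all outputs, apply once" variant mentioned above (hypothetical helpers, not the current API):

```python
def improve_all_outputs(skill, dataset, output_columns):
    # Hypothetical batched variant: collect feedback for every output first,
    # then rewrite the prompt once, so a rewrite targeting one output cannot
    # repeatedly distort the wording of the others.
    feedback_per_output = {
        column: get_feedback(skill, dataset, column)  # hypothetical helper
        for column in output_columns
    }
    combined = "\n\n".join(
        f"## Feedback for {column}\n{feedback}"
        for column, feedback in feedback_per_output.items()
    )
    skill.improve_with_feedback(combined)  # hypothetical single-rewrite call
```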
```diff
@@ -207,9 +207,9 @@ def get_feedback(
             [gt_pred_match.rename("match"), gt], axis=1
         )
         pred_feedback[pred_column] = match_concat.apply(
-            lambda row: "Prediction is correct."
+            lambda row: f"Prediction for {gt_column} is correct."
```
Contributor:

Could you explain the reason for this addition? The initial idea was to use a single column at a time, so pointing out a specific column name might not be necessary - but I'm probably missing your idea.

Author:

Classification (and Transform) skills support multiple outputs and I used this to classify each social media post into multiple categories (each one as a True/False field).

You're correct that on each step we evaluate only one output. But the base prompt doesn't mention which of the outputs we evaluate at that step. This patch is the easiest way I found to make sure the model understands that the feedback relates to a specific output.

If there is only one output, we can simply say "Prediction is correct." as it used to be.
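
In code terms, the fallback could be as simple as this (illustrative sketch; `output_columns` is a hypothetical name):

```python
# Illustrative only: keep the original wording for single-output skills.
feedback = (
    "Prediction is correct."
    if len(output_columns) == 1
    else f"Prediction for {gt_column} is correct."
)
```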

Here is the template from TransformSkill.improve(). Note that it doesn't mention anything about the output name:

"""
## Current prompt
{self.instructions}

## Examples
{examples}

Summarize your analysis about incorrect predictions and suggest changes to the prompt."""

Contributor:

I got your idea about referencing specific columns to help the LLM make the correct assessment. However, the column name used there doesn't carry any signal for the LLM; example from the tests:

```
Prediction for gt_0 is incorrect. Correct answer: "1 1 1"
```

"gt_0" keyword is not presented in input prompt which consists of the string "Input: ... Output: ...". In this case, I'd better create a string like
"Prediction for the field "Output" is incorrect"
assuming there can be multiple outputs.
Let me know if it makes sense.
Happy to merge your PR as soon as we have all tests passed. Thank you.

Author:

Makes sense.
I just need to figure out how to get the field name 🤔
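
One possibility (purely a sketch, not verified against the codebase): if the skill keeps an output template such as "Output: {gt_0}", the user-facing field name could be recovered by parsing the template:

```python
import re

def field_name_for_column(output_template: str, column: str) -> str:
    """Map an internal column name to its user-facing field name by finding
    the label before the `{column}` placeholder,
    e.g. "Output: {gt_0}" -> "Output"."""
    match = re.search(rf"(\w+)\s*:\s*\{{{column}\}}", output_template)
    return match.group(1) if match else column

# field_name_for_column("Input: {text} Output: {gt_0}", "gt_0") -> "Output"
```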

@niklub commented Jan 7, 2024

Hi @chemeris! Thanks for your great contribution!

> I'd be happy to discuss this if there is interest in implementing advanced learning strategies.

Absolutely, it would be very helpful to have different learning strategies and reasoning paths, and to give users the options. Feel free to open a GitHub issue where we can discuss the solutions and reference it in this PR.


@robot-ci-heartex marked this pull request as draft April 5, 2024 07:48