feat(python): Embedding model tuner #1221
base: main
Conversation
pa.Table.from_pydict(relevant_docs),
save_dir / "relevant_docs.lance",
mode=mode,
)
TODO:
- Put all 3 parts of the dataset into a single lance table
Probably doable, but the current design treats docs as a single unit, with any number of query and response pairs possible per doc.
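The trade-off the comment describes can be sketched with plain dicts. All key names and sample data here are illustrative, not the PR's actual schema: merging the three parts into one lance table means denormalising, since a single doc can be the relevant result for many queries.

```python
# Illustrative three-part layout (names are assumptions, not the PR's schema).
corpus = {"doc1": "LanceDB is an embedded vector database."}
queries = {"q1": "What is LanceDB?", "q2": "Is LanceDB embedded?"}
relevant_docs = {"q1": ["doc1"], "q2": ["doc1"]}  # many queries -> one doc

# Collapsing into a single table repeats each doc once per (query, doc) pair.
single_table = [
    {"query_id": qid, "query": queries[qid], "doc_id": did, "text": corpus[did]}
    for qid, doc_ids in relevant_docs.items()
    for did in doc_ids
]
```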
f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')"  # noqa
)

def _wandb_callback(self, score, epoch, steps):
Future TODO:
- This integration doesn't work. Investigate
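For context on the callback shape: sentence-transformers calls `callback(score, epoch, steps)` after each evaluation during `model.fit()`. A dependency-free sketch of that contract, with the wandb call left as a comment (a missing `wandb.init()` before `wandb.log()` is a common cause of this kind of breakage, though the actual failure here is unconfirmed):

```python
# Dependency-free sketch of the (score, epoch, steps) callback contract.
history = []

def _wandb_callback(score: float, epoch: int, steps: int) -> None:
    # A working integration would need an active run, roughly:
    #   wandb.log({"eval_score": score, "epoch": epoch}, step=steps)
    # wandb.log() errors if wandb.init() was never called.
    history.append({"score": score, "epoch": epoch, "steps": steps})

_wandb_callback(0.81, 0, 500)
_wandb_callback(0.86, 1, 1000)
```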
res = model.evaluate(ds)
assert res is not None
Future TODO:
- This is an umbrella test. Need more granular ones
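One way the umbrella check could be split, using a hypothetical hit-rate stand-in for `model.evaluate` (not the PR's implementation): test bounds, a known fixture, and determinism rather than only `assert res is not None`.

```python
def hit_rate(ds):
    # Hypothetical stand-in for model.evaluate(ds): fraction of queries
    # whose relevant doc appears in the retrieved set.
    hits = sum(1 for row in ds if row["expected"] in row["retrieved"])
    return hits / len(ds)

ds = [
    {"expected": "doc1", "retrieved": ["doc1", "doc2"]},
    {"expected": "doc3", "retrieved": ["doc2"]},
]
score = hit_rate(ds)

# Granular checks instead of a single not-None assertion:
assert 0.0 <= score <= 1.0       # bounded metric
assert score == 0.5              # known fixture
assert hit_rate(ds) == score     # deterministic
```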
model: Any,
trainset: QADataset,
valset: Optional[QADataset] = None,
path: Optional[str] = "~/.lancedb/embeddings/models",
TODO: check conflict with .lance files
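One way to address the TODO is a hypothetical helper (not in the PR) that expands the default location and fails fast if model checkpoints would land next to existing `.lance` tables:

```python
from pathlib import Path

def resolve_model_dir(path: str = "~/.lancedb/embeddings/models") -> Path:
    # Hypothetical helper: expand the default path and refuse to mix model
    # checkpoints with existing lance data files in the same directory.
    model_dir = Path(path).expanduser()
    if model_dir.exists() and any(model_dir.glob("*.lance")):
        raise FileExistsError(f"{model_dir} already contains .lance tables")
    return model_dir
```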
llm: BaseLLM,
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
num_questions_per_chunk: int = 2,
) -> "QADataset":
Accept generators
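The suggestion can be sketched by iterating the chunk input lazily, so callers may pass either a list or a generator. Everything below is illustrative: the function name, IDs, and the placeholder question text standing in for the LLM call are all assumptions.

```python
from typing import Dict, Iterable, List, Tuple

def build_qa_pairs(
    chunks: Iterable[str],           # accepts a list OR a generator
    num_questions_per_chunk: int = 2,
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    # Hypothetical generator-friendly flow: consume chunks lazily and
    # fabricate placeholder questions where the LLM call would go.
    queries: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for doc_id, chunk in enumerate(chunks):
        for q in range(num_questions_per_chunk):
            qid = f"{doc_id}-{q}"
            queries[qid] = f"question {q} about: {chunk[:20]}"  # LLM stand-in
            relevant_docs[qid] = [str(doc_id)]
    return queries, relevant_docs

qa_queries, qa_rel = build_qa_pairs(c for c in ["alpha text", "beta text"])
```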
nodes: List["TextChunk"],
queries: Dict[str, str],
relevant_docs: Dict[str, List[str]],
) -> "QADataset":
TODO: test api with WANDS dataset
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
TODO: get rid of this
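A minimal reconstruction of the interface in the diff; the single abstract `fit` method and the subclass name are assumptions. With one required method and one concrete implementation, a plain class or a `Protocol` would serve equally well, which is presumably why the comment suggests dropping the ABC.

```python
from abc import ABC, abstractmethod

class BaseEmbeddingTuner(ABC):
    # Assumed shape of the abstract base: one required training entry point.
    @abstractmethod
    def fit(self) -> None:
        """Run the fine-tuning loop."""

class SentenceTransformersTuner(BaseEmbeddingTuner):
    def fit(self) -> None:
        pass  # the real tuner's training loop would go here
```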
Solves #1021
Design and usage doc - https://www.notion.so/LanceDB-High-Level-Specs-From-ML-perspective-f9b7470b1e4e4c9e8371ad28b574c185?pvs=4#d6a4f29edf3d4ced9954ab8a913ef9f0
Benchmarks for 5 epochs on 65-35 train test split - https://wandb.ai/cayush/lancedb_finetune?nw=nwusercayush