Skip to content

Commit

Permalink
added table of contents to notebook, tested main.py
Browse files Browse the repository at this point in the history
Signed-off-by: Vinay Raman <[email protected]>
  • Loading branch information
vinay-raman committed Nov 12, 2024
1 parent 8dc0d3f commit d5dc0ae
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 211 deletions.
Empty file.
Empty file.
32 changes: 26 additions & 6 deletions tutorials/nemo-retriever-synthetic-data-generation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,64 @@
# limitations under the License.

import argparse
import importlib
import os
import shutil
from typing import Any, List

from retriever_evalset_generator import RetrieverEvalSetGenerator
from tqdm.dask import TqdmCallback

from nemo_curator import AsyncOpenAIClient, ScoreFilter, Sequential
from nemo_curator.datasets import DocumentDataset
from nemo_curator.filters import AnswerabilityFilter, EasinessFilter
from nemo_curator.modules.config import RetrieverEvalSDGConfig
from nemo_curator.modules.filter import Score, ScoreFilter

config = importlib.import_module(
"tutorials.nemo-retriever-synthetic-data-generation.config.config"
)
retriever_evalset_generator = importlib.import_module(
"tutorials.nemo-retriever-synthetic-data-generation.retriever_evalset_generator"
)


def get_pipeline(args: Any) -> Any:

cfg = RetrieverEvalSDGConfig.from_yaml(args.pipeline_config)
cfg = config.RetrieverEvalSDGConfig.from_yaml(args.pipeline_config)
# update api_key from input args
cfg.api_key = args.api_key

sdg_pipeline = Sequential(
[
RetrieverEvalSetGenerator(cfg),
retriever_evalset_generator.RetrieverEvalSetGenerator(cfg),
]
)
filters = []
if cfg.easiness_filter:
filters.append(
ScoreFilter(
EasinessFilter(cfg),
EasinessFilter(
cfg.base_url,
cfg.api_key,
cfg.easiness_filter,
cfg.percentile,
cfg.truncate,
cfg.batch_size,
),
text_field=["text", "question"],
score_field="easiness_scores",
)
)
if cfg.answerability_filter:
filters.append(
ScoreFilter(
AnswerabilityFilter(cfg),
AnswerabilityFilter(
cfg.base_url,
cfg.api_key,
cfg.answerability_filter,
cfg.answerability_system_prompt,
cfg.answerability_user_prompt_template,
cfg.num_criteria,
),
text_field=["text", "question"],
score_field="answerability_scores",
)
Expand Down
Loading

0 comments on commit d5dc0ae

Please sign in to comment.