Skip to content
This repository has been archived by the owner on Jun 24, 2020. It is now read-only.

Commit

Permalink
feat(importer): add data importer class
Browse files Browse the repository at this point in the history
  • Loading branch information
Philippe Cote-Boucher committed Sep 12, 2019
1 parent 476b3aa commit 176b12e
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 8 deletions.
1 change: 1 addition & 0 deletions rasa_addons/importers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from rasa_addons.importers.botfront import BotfrontFileImporter
90 changes: 90 additions & 0 deletions rasa_addons/importers/botfront.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import logging
import os
import copy
from typing import Optional, Text, Union, List, Dict

from rasa import data
from rasa.core.domain import Domain, InvalidDomain
from rasa.core.interpreter import RegexInterpreter, NaturalLanguageInterpreter
from rasa.core.training.structures import StoryGraph
from rasa.core.training.dsl import StoryFileReader
from rasa.importers import utils
from rasa.importers.importer import TrainingDataImporter
from rasa.nlu.training_data import TrainingData
from rasa.utils import io as io_utils

logger = logging.getLogger(__name__)


class BotfrontFileImporter(TrainingDataImporter):

def __init__(
self,
config_paths: Optional[Dict[Text, Text]] = None,
domain_path: Optional[Text] = None,
training_data_path: Optional[Text] = None,
):
# keep only policies in core_config
self.core_config = {'policies': io_utils.read_config_file(
config_paths[list(config_paths.keys())[0]]
)['policies']}
self._stories_path = os.path.join(training_data_path, 'stories.md')

# keep all but policies in nlu_config
self.nlu_config = {}
for lang in config_paths:
self.nlu_config[lang] = io_utils.read_config_file(config_paths[lang])
del self.nlu_config[lang]['policies']
self.nlu_config[lang]['data'] = 'data_for_' + lang # so rasa.nlu.train.train makes the right get_nlu_data call
self.nlu_config[lang]['path'] = os.path.join(training_data_path, 'nlu', '{}.md'.format(lang))

self._domain_path = domain_path

async def get_core_config(self) -> Dict:
return self.core_config

async def get_nlu_config(self, languages = True) -> Dict:
if not isinstance(languages, list):
languages = self.nlu_config.keys()
return {lang: self.nlu_config[lang] if lang in languages else False for lang in self.nlu_config.keys()}

async def get_stories(
self,
interpreter: "NaturalLanguageInterpreter" = RegexInterpreter(),
template_variables: Optional[Dict] = None,
use_e2e: bool = False,
exclusion_percentage: Optional[int] = None,
) -> StoryGraph:

story_steps = await StoryFileReader.read_from_files(
[self._stories_path],
await self.get_domain(),
interpreter,
template_variables,
use_e2e,
exclusion_percentage,
)
return StoryGraph(story_steps)

async def get_nlu_data(self, languages = True) -> Dict[Text, TrainingData]:
if isinstance(languages, str) and languages.startswith('data_for_'):
lang = languages.replace('data_for_', '')
return utils.training_data_from_paths([self.nlu_config[lang]['path']], 'xx')
if not isinstance(languages, list):
languages = self.nlu_config.keys()
return {lang: utils.training_data_from_paths([self.nlu_config[lang]['path']], 'xx')
for lang in languages}

async def get_domain(self) -> Domain:
domain = Domain.empty()
try:
domain = Domain.load(self._domain_path)
domain.check_missing_templates()
except InvalidDomain as e:
logger.warning(
"Loading domain from '{}' failed. Using empty domain. Error: '{}'".format(
self._domain_path, e.message
)
)

return domain
16 changes: 8 additions & 8 deletions rasa_addons/nlu/components/gazette.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class Gazette(Component):
}

def __init__(self,
component_config: Text = None,
gazette: Optional[Dict] = None) -> None:
component_config: Text = None,
gazette: Optional[Dict] = None) -> None:

super(Gazette, self).__init__(component_config)
self.gazette = gazette if gazette else {}
Expand Down Expand Up @@ -63,12 +63,12 @@ def train(

@classmethod
def load(cls,
component_meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional['Gazette'] = None,
**kwargs: Any
) -> 'Gazette':
component_meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional['Gazette'] = None,
**kwargs: Any
) -> 'Gazette':
from rasa.nlu.utils import read_json_file

td = read_json_file(os.path.join(model_dir, "training_data.json"))
Expand Down

0 comments on commit 176b12e

Please sign in to comment.