How to pass a JSON object directly from Python to rasa-nlu for training

Posted 2024-09-27 23:27:00


I am training data with rasa nlu. According to the documentation at http://nlu.rasa.ai/python.html, the following code has to be used to train on the data in the file demo-rasa.json:

from rasa_nlu.converters import load_data
from rasa_nlu.config import RasaNLUConfig
from rasa_nlu.model import Trainer

training_data = load_data('data/examples/rasa/demo-rasa.json')
trainer = Trainer(RasaNLUConfig("sample_configs/config_spacy.json"))
trainer.train(training_data)
model_directory = trainer.persist('./projects/default/')

But how can we read the data from a JSON object for training instead?


3 answers

There is a simple way to do this, but it is hard to find because RASA's code is poorly documented.

You have to create a JSON object in the following format.

training_data = {'rasa_nlu_data': {"common_examples": training_examples,
                                   "regex_features": [],
                                   "lookup_tables": [],
                                   "entity_synonyms": []
                                   }}

In this JSON, training_examples is a list, and it should contain data like the following.

^{pr2}$
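
As a rough illustration of what such a list might look like, here is a minimal sketch. It assumes the usual Rasa NLU JSON layout, where each entry has `text`, `intent`, and `entities` with character offsets. The sentences, intent names, and the `check_entity_offsets` helper are invented for illustration and are not taken from the original answer.

```python
# Hypothetical examples: the texts, intents and entities are invented.
training_examples = [
    {"text": "hello there", "intent": "greet", "entities": []},
    {
        "text": "show me chinese restaurants",
        "intent": "restaurant_search",
        "entities": [
            {"start": 8, "end": 15, "value": "chinese", "entity": "cuisine"}
        ],
    },
]

# Same wrapper structure as shown above.
training_data = {
    "rasa_nlu_data": {
        "common_examples": training_examples,
        "regex_features": [],
        "lookup_tables": [],
        "entity_synonyms": [],
    }
}


def check_entity_offsets(examples):
    """Return True if every entity's start/end slice of the text equals its value."""
    for example in examples:
        for entity in example.get("entities", []):
            if example["text"][entity["start"]:entity["end"]] != entity["value"]:
                return False
    return True
```

The offset check is only a sanity check for the simple case where `value` is the literal substring. If synonyms are used, `value` may legitimately differ from the text slice.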

With this in place, you can train like this :)

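from rasa.nlu import config
from rasa.nlu.model import Trainer
# Assumption (Rasa 1.x layout): RasaReader converts the dict above into a TrainingData object
from rasa.nlu.training_data.formats.rasa import RasaReader

# The config can also be loaded from a dict like this
def get_train_config():
    return {'language': 'en',
            'pipeline': [
                {'name': 'WhitespaceTokenizer'},
                {'name': 'ConveRTFeaturizer'},
                {'name': 'EmbeddingIntentClassifier'}
                ],
            'data': None,
            'policies': [
                {'name': 'MemoizationPolicy'},
                {'name': 'KerasPolicy'},
                {'name': 'MappingPolicy'}
                ]
            }

# trainer.train() expects a TrainingData object, not the raw dict
data = RasaReader().read_from_json(training_data)
trainer = Trainer(config._load_from_dict(get_train_config()))
interpreter = trainer.train(data)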

I built a Flask application that takes the JSON object from the request body instead of reading it from a file.

This code converts an existing LUIS JSON, using spaCy for entities and sklearn-crfsuite for intent recognition.

from flask import Flask, jsonify, request
from flask_cors import CORS
import json, os, msvcrt, psutil, subprocess, datetime

app = Flask(__name__)

CORS(app)

with app.app_context():
    with app.test_request_context():

        #region REST based RASA API
        serverExecutablePID = 0     
        hasAPIStarted = False
        configFileDirectory = "C:\\Code\\RasaAPI\\RASAResources\\config"
        chitChatModel = "ChitChat"
        assetsDirectory = "C:\\Code\\RasaAPI\\RASAResources"

        def createSchema(SchemaPath, dataToBeWritten):
            try:
                    #write LUIS or RASA JSON Schema in json file locking the file to avoid race condition using Python's Windows msvcrt binaries
                    with open(SchemaPath, "w") as SchemaCreationHandle:
                        msvcrt.locking(SchemaCreationHandle.fileno(), msvcrt.LK_LOCK, os.path.getsize(SchemaPath))
                        json.dump(dataToBeWritten, SchemaCreationHandle, indent = 4, sort_keys=False)
                        SchemaCreationHandle.close()

                    #Check if written file actually exists on disk or not
                    doesFileExist = os.path.exists(SchemaPath)                    
                    return doesFileExist

            except Exception as ex:
                return str(ex.args)


        def appendTimeStampToModel(ModelName):
            return ModelName + '_{:%Y%m%d-%H%M%S}.json'.format(datetime.datetime.now())

        def appendTimeStampToConfigSpacy(ModelName):
            return ModelName + '_config_spacy_{:%Y%m%d-%H%M%S}.json'.format(datetime.datetime.now())

        def createConfigSpacy(ModelName, DataPath, ConfigSpacyPath, TrainedModelsPath, LogDataPath):
            try:
                    with open(ConfigSpacyPath, "w") as configSpacyFileHandle:
                        msvcrt.locking(configSpacyFileHandle.fileno(), msvcrt.LK_LOCK, os.path.getsize(ConfigSpacyPath))
                        configDataToBeWritten = dict({
                        "project": ModelName,
                        "data": DataPath,
                        "path": TrainedModelsPath,
                        "response_log": LogDataPath,
                        "log_level": "INFO",
                        "max_training_processes": 1,
                        "pipeline": "spacy_sklearn",
                        "language": "en",
                        "emulate": "luis",
                        "cors_origins": ["*"],
                        "aws_endpoint_url": None,
                        "token": None,
                        "num_threads": 2,
                        "port": 5000
                        })
                        json.dump(configDataToBeWritten, configSpacyFileHandle, indent = 4, sort_keys=False)

                    return os.path.getsize(ConfigSpacyPath) > 0

            except Exception as ex:
                return str(ex.args)

        def TrainRASA(configFilePath):
            try:  
                trainingString = 'start /wait python -m rasa_nlu.train -c ' + '\"' + os.path.normpath(configFilePath) + '\"'
                returnCode = subprocess.call(trainingString, shell = True)
                return returnCode

            except Exception as ex:
                return str(ex.args)

        def StartRASAServer(configFileDirectory, ModelName):
            #region Server starting logic
            try:
                global hasAPIStarted
                global serverExecutablePID
                #1) for finding which is the most recent config_spacy
                root, dirs, files = next(os.walk(os.path.normpath(configFileDirectory)))

                configFiles = [configFile for configFile in files if ModelName in configFile]
                configFiles.sort(key = str.lower, reverse = True)
                mostRecentConfigSpacy = os.path.join(configFileDirectory, configFiles[0])

                serverStartingString = 'start /wait python -m rasa_nlu.server -c ' + '\"' + os.path.normpath(mostRecentConfigSpacy) + '\"'

                serverProcess = subprocess.Popen(serverStartingString, shell = True)
                serverExecutablePID = serverProcess.pid

                pingReturnCode = 1
                while(pingReturnCode):
                    pingReturnCode = os.system("netstat -na | findstr /i 5000")
                if(pingReturnCode == 0):
                    hasAPIStarted = True

                return pingReturnCode

            except Exception as ex:
                return jsonify({"message": "Failed because: " + str(ex.args) , "success": False})
            #endregion

        def KillProcessWindow(hasAPIStarted, serverExecutablePID):
            if(hasAPIStarted == True and serverExecutablePID != 0):
                me = psutil.Process(serverExecutablePID)
                for child in me.children():
                    child.kill()


        @app.route('/api/TrainRASA', methods = ['POST'])
        def TrainRASAServer():
            try:
                #get request body of POST request
                postedJSONData = json.loads(request.data, strict = False)

                if postedJSONData["data"] is not None:
                    print("Valid data")
                    #region JSON file building logic
                    modelName = postedJSONData["modelName"]
                    modelNameWithExtension = appendTimeStampToModel(modelName)
                    schemaPath = os.path.join(assetsDirectory, "data", modelNameWithExtension)
                    print(createSchema(schemaPath, postedJSONData["data"]))
                    #endregion

                    #region config file creation logic
                    configFilePath = os.path.join(assetsDirectory, "config", appendTimeStampToConfigSpacy(modelName))
                    logsDirectory = os.path.join(assetsDirectory, "logs")
                    trainedModelDirectory = os.path.join(assetsDirectory, "models")
                    configFileCreated = createConfigSpacy(modelName, schemaPath, configFilePath, trainedModelDirectory, logsDirectory)
                    #endregion

                    if(configFileCreated == True):
                        #region Training RASA NLU with schema
                        TrainingReturnCode = TrainRASA(configFilePath)
                        #endregion

                        if(TrainingReturnCode == 0):
                            return jsonify({"message": "Successfully trained RASA NLU with modelname: " + modelName, "success": True})
                            # KillProcessWindow(hasAPIStarted, serverExecutablePID)
                            # serverStartingReturnCode = StartRASAServer(configFileDirectory, modelName)
                            # #endregion

                            # if serverStartingReturnCode == 0:                    
                            #     return jsonify({"message": "Successfully started RASA server on port 5000", "success": True})

                            # elif serverStartingReturnCode is None:
                            #     return jsonify({"message": "Could not start RASA server, request timed out", "success": False})

                        else:
                            return jsonify({"message": "Something wrong happened while training RASA NLU!", "success": False})

                    else:
                        return jsonify({"message": "Could not create config file for RASA NLU", "success": False})

                #throw exception if request body is empty
                return jsonify({"message": "Please enter some JSON, JSON seems to be empty", "success": False})

            except Exception as ex:
                return jsonify({"Reason": "Failed because: " + str(ex.args), "success": False})

        @app.route('/api/StopRASAServer', methods = ['GET'])
        def StopRASAServer():
            try:
                global serverExecutablePID

                # "or" made this condition always true; both checks must hold
                if serverExecutablePID != 0 and serverExecutablePID is not None:
                    me = psutil.Process(serverExecutablePID)
                    for child in me.children():
                        child.kill()
                    return jsonify({"message": "Server stopped....", "success": True})
            except Exception as ex:
                return jsonify({"message": "Something went wrong while shutting down the server because: " + str(ex.args), "success": False})

        if __name__ == "__main__":
            StartRASAServer(configFileDirectory, chitChatModel)
            app.run(debug=False, threaded = True, host='0.0.0.0', port = 5050)
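
For reference, a client call to the `/api/TrainRASA` route above could be prepared roughly as follows. This is only a sketch: the `modelName` and `data` keys, the route, and port 5050 come from the code above, while `localhost`, the tiny example data set, and the `build_train_request` helper are assumptions made for illustration.

```python
import json
import urllib.request

# Hypothetical payload: the endpoint above reads "modelName" and "data".
payload = {
    "modelName": "ChitChat",
    "data": {
        "rasa_nlu_data": {
            "common_examples": [
                {"text": "hello", "intent": "greet", "entities": []}
            ]
        }
    },
}


def build_train_request(payload, base_url="http://localhost:5050"):
    """Build (but do not send) a POST request carrying the JSON body."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        base_url + "/api/TrainRASA",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_train_request(payload)

# To actually send it (requires the Flask app above to be running):
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```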

If you look at the implementation of ^{}, it performs two steps:

  1. Guess the file format
  2. Load the file using the appropriate loading method

The simplest solution is to write the JSON object to a file or a StringIO object.
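
The file-based workaround could be sketched as below, using only the standard library. The helper name `write_json_to_temp_file` is hypothetical, and the commented-out `load_data(path)` call refers to the loader from the question. Whether that loader also accepts a file-like object such as StringIO has not been verified here.

```python
import json
import os
import tempfile


def write_json_to_temp_file(obj):
    """Write obj as JSON to a new temporary file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w", encoding="utf-8") as handle:
        json.dump(obj, handle)
    return path


training_data = {"rasa_nlu_data": {"common_examples": []}}
path = write_json_to_temp_file(training_data)
try:
    # With rasa_nlu installed, the loader from the question would go here:
    # loaded = load_data(path)
    with open(path, encoding="utf-8") as handle:
        round_tripped = json.load(handle)
finally:
    # Clean up the temporary file once the loader is done with it.
    os.remove(path)
```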

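Alternatively, you can pick the specific loading function you need, e.g. ^{}, and separate the file reading from it. For this example, you could probably just take the whole function and remove the ^{} line.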
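
The general idea of splitting file reading from parsing can be sketched like this. Both function names are hypothetical and the parser is a simplified stand-in, not the actual rasa_nlu code: one function works on an already loaded dict, and a thin wrapper handles the file.

```python
import json


def parse_training_dict(js):
    """Hypothetical parser: works on an already loaded dict."""
    data = js["rasa_nlu_data"]
    return {
        "common_examples": data.get("common_examples", []),
        "entity_synonyms": data.get("entity_synonyms", []),
        "regex_features": data.get("regex_features", []),
    }


def load_training_file(path):
    """File-based wrapper: only reads the file, then delegates to the parser."""
    with open(path, encoding="utf-8") as handle:
        return parse_training_dict(json.load(handle))
```

With this split, code that already holds a JSON object calls `parse_training_dict` directly, and file-based callers keep working through `load_training_file`.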

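I am a bit surprised to find that there is currently no way to read an already loaded JSON object. If you decide to adapt the function to handle this case, you might consider submitting a pull request for it.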
