Skip to content

Multi-Label Classifier

MLTClassifier classifies and labels records according to a detailed specification. This is useful for tasks such as extracting the deeper intent behind user questions.

LLMMultiLabelTextClassifier(name, cred={}, platform=defaults.LLM_PLATFORM, model=defaults.LLM_MODEL)

Bases: AgentEnablersMixin

A bot that takes in some text and performs multi-label classification on it.

Initialize the bot. name: name of the bot; cred: credentials object; platform: name of the platform backend to use (currently defaults to the OpenAI platform; will be extended in the future to support other platforms); model: name of the LLM model to use.

Source code in llmsdk/agents/mltclassifier.py
def __init__(self,
             name,
             cred=None,
             platform=defaults.LLM_PLATFORM,
             model=defaults.LLM_MODEL):
    """
    Initialize the multi-label classifier bot.

    name: name of the bot
    cred: credentials object handed to the LLM backend; when omitted, a
          fresh empty dict is created for this instance
    platform: name of the platform backend to use
              (defaults to defaults.LLM_PLATFORM)
    model: name of the LLM model to use (defaults to defaults.LLM_MODEL)
    """

    start_time = time.time()

    # logging
    self.logger = get_logger()

    # defaults (metadata is replaced with the full record further below)
    self.metadata = {}
    self.sanity_entries = {}
    self.max_llm_tokens = 512

    # name
    self.agent_name = name
    # NOTE(review): "query-router" looks carried over from another agent;
    # confirm nothing downstream depends on it before renaming.
    self.agent_type = "query-router"

    # creds: use a per-instance dict rather than a shared mutable default,
    # which every instance constructed without `cred` would otherwise alias
    self.cred = {} if cred is None else cred
    # LLM params
    self.platform = platform
    self.model = model

    # init the llm and embeddings objects
    self.llm, self.embeddings = self._get_llm_objs(platform=self.platform,
                                                   model=self.model,
                                                   cred=self.cred)

    # note metadata for this agent
    self.metadata = {
        "agent": {
            "name": self.agent_name,
            "platform": self.platform,
            "model": self.model,
        },
        "events": []
    }
    # log that the agent is ready
    duration = time.time() - start_time
    self._log_event(agent_events._EVNT_READY, duration)

get_prompt(query, label_spec)

Generate a prompt for querying the LLM.

Source code in llmsdk/agents/mltclassifier.py
    def get_prompt(self, query, label_spec):
        """
        generate a prompt for querying the LLM
        """
        # construct the prompt template

        defualt_persona = """You are a highly advanced, AI-enabled bot, that performs multi-label text classification"""

        # get the bot's persona
        persona = label_spec.get("persona", defualt_persona)

        # get the bot's specified instruction set
        instructions = label_spec.get("instructions", "")

        # get the synonymns
        syns = []
        for term, synonymns in label_spec.get('synonymns', {}).items():
            syn = f"  - {term}: {', '.join(synonymns)}"
            syns.append(syn)
        synonymns = "\n".join(syns)
        if len(synonymns) > 0:
            synonymns = f"""Here is a set of comma-separated synonymns for various phrases that may be asked about:
{synonymns}"""

        # get the output format
        output_format = label_spec.get('output_format', {})
        output_type = output_format.get('type', 'json')
        output_sample = output_format.get('sample', None)

        if output_sample:
            output_format = f"""Always respond by formatting your response as a {output_type} object EXACTLY as follows:
{output_sample}"""
        else:
            output_format = f"""Always respond by formatting your response as a {output_type} object"""

        # construct the system message
        sys_msg = "\n\n".join([persona, instructions, synonymns, output_format])

        # construct the human message
        human_msg = f"""
Here is the question:
------ BEGIN QUESTION ------
{query}
------- END QUESTION -------

Your response:
"""

        messages = [
            SystemMessage(content=sys_msg),
            HumanMessage(content=human_msg),
        ]

        return messages

label(query, label_spec={})

Run a prompt against the LLM using the label spec.

Source code in llmsdk/agents/mltclassifier.py
def label(self, query, label_spec=None):
    """
    Run the classification prompt for `query` against the LLM.

    query: text to classify
    label_spec: spec dict passed to get_prompt (persona, instructions,
                synonymns, output_format); defaults to an empty spec

    Returns (success, response). On success, response is the JSON-decoded
    LLM reply. On any failure (LLM call error or unparseable JSON), success
    is False and response is None.
    """
    import copy

    if label_spec is None:
        label_spec = {}

    start_time = time.time()

    success = True
    response = None
    # Initialized up front so the query event can always be logged, even
    # when the LLM call raises before usage stats are collected.
    stats = {}

    # construct the prompt for the classifier
    prompt = self.get_prompt(query, label_spec)

    # run the query
    try:
        if self.platform in ['openai', 'azure']:
            with get_openai_callback() as cb:
                llm_reply = self.llm(prompt)
            stats = {
                "total_tokens": cb.total_tokens,
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": round(cb.total_cost, 4)
            }
        else:
            llm_reply = self.llm(prompt)

        response = json.loads(llm_reply.content)

    except Exception:
        # Best-effort: report failure to the caller instead of raising, while
        # still letting KeyboardInterrupt/SystemExit propagate.
        success = False
        response = None

    # log the event
    params = {
        "query": query,
        "mode": "internal",
        # deep copy so callers mutating the returned result cannot alter the
        # logged event; works for any JSON value, not only dicts
        "result": copy.deepcopy(response),
        "stats": stats,
    }
    duration = time.time() - start_time
    self._log_event(agent_events._EVNT_QUERY, duration, params=params)

    return success, response

query_to_text(agent, query, label_spec, commentary)

Take a query and a label_spec, use the MLTClassifier agent to get a set of labels, then use those labels to construct a key and use that key to look up the data in the commentary dict.

Source code in llmsdk/agents/mltclassifier.py
def query_to_text(agent, query, label_spec, commentary):
    """
    Map a query to commentary text via the labels an MLTClassifier assigns.

    Calls agent.label(query, label_spec), builds a key from the 'insight'
    and 'cohort' labels (joined with '_', skipping missing/None/empty
    labels), and returns commentary[key]['text'].

    agent: object exposing label(query, label_spec) -> (success, result)
    query: text to classify
    label_spec: label spec passed through to the agent
    commentary: dict mapping key -> dict containing a 'text' entry

    Returns the commentary text, or "Error" when labelling fails, the
    result is not a dict, or no matching commentary text exists.
    """
    err = "Error"

    # first, get the labels
    success, result = agent.label(query, label_spec)
    # The LLM may return any JSON value; only a dict carries named labels.
    if not success or not isinstance(result, dict):
        return err

    # construct the key: skip missing/null/empty labels so e.g.
    # {"cohort": null} does not break the join, and coerce the rest to str
    labels = (result.get(field) for field in ('insight', 'cohort'))
    key = "_".join(str(lbl) for lbl in labels if lbl not in (None, ''))

    # lookup
    entry = commentary.get(key)
    if not isinstance(entry, dict):
        return err
    return entry.get('text', err)