Commit e5a0db6

Add NER model variant with required fields
In order to have an NER model that is simpler for internal regex/CFG representations, add an NER variant that requires all fields and does not include a default value. In particular, this makes it possible to evaluate a version of NER for Outlines and provides an additional point of comparison for other libraries.
1 parent 2576774 commit e5a0db6
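
To illustrate what "requires all fields and does not include a default value" means at the schema level, here is a minimal sketch. The entity names PERSON/ORG are hypothetical, and the defaulted counterpart is an assumption for comparison, not the repository's existing ner_model: a Pydantic field declared without a default is listed under "required" in the generated JSON schema, which keeps the internal regex/CFG representation a constrained-decoding library has to build simpler.

    from typing import Optional

    from pydantic import Field, create_model

    entities = ["PERSON", "ORG"]  # hypothetical entity names, for illustration only

    # Required variant: no default is supplied, so every field must appear in the output.
    Required = create_model(
        "Required",
        **{name: (list[str], Field(description="")) for name in entities},
    )

    # Defaulted counterpart for comparison (an assumption, not the repo's ner_model).
    Defaulted = create_model(
        "Defaulted",
        **{name: (Optional[list[str]], Field(default=None, description="")) for name in entities},
    )

    print(Required.model_json_schema().get("required"))   # ['PERSON', 'ORG']
    print(Defaulted.model_json_schema().get("required"))  # None: nothing is required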

File tree

4 files changed: +45 −6 lines

config.yaml
data_sources/data_models.py
frameworks/base.py
main.py

config.yaml

Lines changed: 23 additions & 1 deletion
@@ -149,6 +149,18 @@ OutlinesFramework:
       retries: 0 # Oulines transformers has no retry parameter
       source_data_pickle_path: "data/multilabel_classification.pkl"
       # sample_rows: 2
+  - task: "ner_required_fields"
+    n_runs: 10
+    init_kwargs:
+      prompt: "Extract and resolve a list of entities from the following text: {text}"
+      # switch to prompt with JSON schema when supported:
+      #prompt: "Extract and resolve a list of entities from the following text: {text}.\nRespond in the following json schema: {json_schema}:\n"
+      llm_model: "unsloth/llama-3-8b-Instruct-bnb-4bit"
+      llm_model_family: "transformers"
+      retries: 0
+      source_data_pickle_path: "data/ner.pkl"
+      max_length: 4096
+      #sample_rows: 2
   - task: "synthetic_data_generation"
     n_runs: 100
     init_kwargs:
@@ -177,6 +189,16 @@ LMFormatEnforcerFramework:
       source_data_pickle_path: "data/ner.pkl"
       max_length: 4096
       # sample_rows: 2
+  - task: "ner_required_fields"
+    n_runs: 10
+    init_kwargs:
+      prompt: "Extract and resolve a list of entities from the following text: {text}.\nRespond in the following json schema: {json_schema}:\n"
+      llm_model: "unsloth/llama-3-8b-Instruct-bnb-4bit"
+      llm_model_family: "transformers"
+      retries: 0
+      source_data_pickle_path: "data/ner.pkl"
+      max_length: 4096
+      # sample_rows: 2
   - task: "synthetic_data_generation"
     n_runs: 100
     init_kwargs:
@@ -234,4 +256,4 @@ ModelsmithFramework:
       # llm_model_family: "transformers"
       # retries: 0 # Oulines transformers has no retry parameter
       # source_data_pickle_path: "data/ner.pkl"
-      # # sample_rows: 2
+      # # sample_rows: 2

data_sources/data_models.py

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,7 @@
 from enum import Enum
 from typing import Any, Optional, Type
 
-from pydantic import BaseModel, create_model, field_validator
+from pydantic import BaseModel, Field, create_model, field_validator
 from pydantic_core import PydanticUndefined
 
 
@@ -24,6 +24,14 @@ def ner_model(ner_entities):
     return NER
 
 
+def ner_required_fields_model(ner_entities):
+    fields = {name: (list[str], Field(description="")) for name in ner_entities}
+
+    NERRequiredFields = create_model("NERRequiredFields", **fields)
+
+    return NERRequiredFields
+
+
 def synthetic_data_generation_model():
     class UserAddress(BaseModel):
         street: str
frameworks/base.py

Lines changed: 11 additions & 2 deletions
@@ -12,6 +12,7 @@
 from data_sources.data_models import (
     multilabel_classification_model,
     ner_model,
+    ner_required_fields_model,
     synthetic_data_generation_model,
 )
 
@@ -82,7 +83,7 @@ def experiment(
 
     def experiment_decorator(func):
         def wrapper(*args, **kwargs):
-            allowed_tasks = ["multilabel_classification", "ner", "synthetic_data_generation"]
+            allowed_tasks = ["multilabel_classification", "ner", "ner_required_fields", "synthetic_data_generation"]
            if task not in allowed_tasks:
                raise ValueError(
                    f"{task} is not allowed. Allowed values are {allowed_tasks}"
@@ -119,7 +120,7 @@ def wrapper(*args, **kwargs):
                framework_metrics = {
                    "accuracy": accurate / num_successful if num_successful else 0
                }
-            elif task == "ner":
+            elif task in ("ner", "ner_required_fields"):
                framework_metrics = []
                for response in responses:
                    framework_metrics.append(calculate_metrics(expected_response, response))
@@ -191,6 +192,14 @@ def __init__(self, *args, **kwargs) -> None:
 
             self.response_model = ner_model(self.entities)
 
+        elif self.task == "ner_required_fields":
+            # Identify the entities
+            self.entities = list(
+                {key for d in self.source_data["labels"] for key in d.keys()}
+            )
+
+            self.response_model = ner_required_fields_model(self.entities)
+
         elif self.task == "synthetic_data_generation":
             self.response_model = synthetic_data_generation_model()
 
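The entity list passed to the factory is simply the set of label keys seen anywhere in the source data. A small standalone sketch with made-up labels shows the same comprehension as in the diff:

    # Made-up label dicts standing in for self.source_data["labels"].
    labels = [
        {"PERSON": ["Ada Lovelace"], "ORG": ["Analytical Engines Ltd"]},
        {"PERSON": ["Charles Babbage"], "DATE": ["1837"]},
    ]

    # Union of keys, deduplicated via a set; order is not guaranteed.
    entities = list({key for d in labels for key in d.keys()})
    print(sorted(entities))  # ['DATE', 'ORG', 'PERSON']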
main.py

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ def generate_results(
     task: str = "multilabel_classification",
 ):
 
-    allowed_tasks = ["multilabel_classification", "ner", "synthetic_data_generation"]
+    allowed_tasks = ["multilabel_classification", "ner", "ner_required_fields", "synthetic_data_generation"]
     if task not in allowed_tasks:
         raise ValueError(f"{task} is not allowed. Allowed values are {allowed_tasks}")
 
@@ -127,7 +127,7 @@ def generate_results(
     logger.info(f"Latencies:\n{metrics.latency_metric(latencies, 95)}")
 
     # NER Micro Metrics
-    if task == "ner":
+    if task in ("ner", "ner_required_fields"):
         micro_metrics_df = metrics.ner_micro_metrics(results)
         logger.info(f"NER Micro Metrics:\n{micro_metrics_df}")
 