Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*.DS_Store
__pycache__/
__pycache__/
benchmark/data/*context/*
test/*
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ openai = "^1.3.5"
anthropic = "^0.7.4"
datasets = "^2.15.0"
plotly = "^5.18.0"
vllm = "^0.2.6"

[tool.poetry.scripts]
refchecker-cli = "refchecker.cli:main"
Expand Down
156 changes: 95 additions & 61 deletions refchecker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from argparse import ArgumentParser, RawTextHelpFormatter
from tqdm import tqdm

from .extractor import Claude2Extractor, GPT4Extractor
from .extractor import Claude2Extractor, GPT4Extractor, MistralExtractor
from .checker import Claude2Checker, GPT4Checker, NLIChecker
from .retriever import GoogleRetriever
from .aggregator import strict_agg, soft_agg, major_agg
Expand All @@ -12,78 +12,106 @@
def get_args():
    """Parse and return command-line arguments for the refchecker CLI.

    Defines the positional ``mode`` (extract / check / extract-check) plus
    all model, key-file, and retrieval options, then parses ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed CLI options.
    """
    parser = ArgumentParser(formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "mode",
        nargs="?",
        choices=["extract", "check", "extract-check"],
        help="extract: Extract triplets from provided responses.\ncheck: Check whether the provided triplets are factual.\nextract-check: Extract triplets and check whether they are factual.",
    )
    parser.add_argument(
        "--input_path", type=str, required=True, help="Input path to the json file."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path to the result json file.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./.cache",
        help="Path to the cache directory. Default: ./.cache",
    )
    parser.add_argument(
        "--extractor_name",
        type=str,
        default="claude2",
        choices=["gpt4", "claude2", "mistral", "mixtral"],
        help="Model used for extracting triplets. Default: claude2.",
    )
    parser.add_argument(
        "--extractor_max_new_tokens",
        type=int,
        default=500,
        help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500",
    )
    parser.add_argument(
        "--checker_name",
        type=str,
        default="claude2",
        choices=["gpt4", "claude2", "nli"],
        help="Model used for checking whether the triplets are factual. Default: claude2.",
    )
    parser.add_argument(
        "--retriever_name",
        type=str,
        default="google",
        choices=["google"],
        help="Model used for retrieving reference (currently only google is supported). Default: google.",
    )
    parser.add_argument(
        "--aggregator_name",
        type=str,
        default="soft",
        choices=["strict", "soft", "major"],
        help="Aggregator used for aggregating the results from multiple triplets. Default: soft.\n* strict: If any of the triplets is Contradiction, the response is Contradiction.\nIf all of the triplets are Entailment, the response is Entailment. Otherwise, the\nresponse is Neutral.\n* soft: The ratio of each category is calculated.\n* major: The category with the most votes is selected.",
    )
    parser.add_argument(
        "--openai_key",
        type=str,
        default="",
        # Single literal: the original had an accidental implicit
        # concatenation ("...are" " used.") left over from reformatting.
        help="Path to the openai api key file. Required if openAI models are used.",
    )
    parser.add_argument(
        "--anthropic_key",
        type=str,
        default="",
        help="Path to the Anthropic api key file. Required if the Anthropic Claude2 api is used.",
    )
    parser.add_argument(
        "--aws_bedrock_region",
        type=str,
        default="",
        help="AWS region where the Amazon Bedrock api is deployed. Required if the Amazon Bedrock api is used.",
    )
    parser.add_argument(
        "--use_retrieval",
        action="store_true",
        help="Whether to use retrieval to find the reference for checking. Required if the reference\nfield in input data is not provided.",
    )
    parser.add_argument(
        "--serper_api_key",
        type=str,
        default="",
        help="Path to the serper api key file. Required if the google retriever is used.",
    )
    parser.add_argument(
        "--local_llm_checkpoint_path",
        type=str,
        default=None,
        help="Specify the local LLM checkpoint path if you use one other than the official release. By default, the official release of the specified LLM is used.",
    )
    parser.add_argument(
        "--extractor_ngpus",
        type=int,
        default=None,
        help="Specify the number of GPUs you want to use in launching a local model. By default, 1 is used for small models and up to all are used for larger ones.",
    )
    parser.add_argument(
        "--nli_device",
        type=int,
        default=None,
        help="Specify the device in using NLI model as checker. By default uses 0.",
    )

    return parser.parse_args()
Expand Down Expand Up @@ -124,21 +152,29 @@ def extract(args):
extractor = Claude2Extractor()
elif args.extractor_name == "gpt4":
extractor = GPT4Extractor()
elif args.extractor_name in ["mixtral", "mistral"]:
extractor = MistralExtractor(
model_path=args.local_llm_checkpoint_path,
use_gpu_num=args.extractor_ngpus,
model_name=args.extractor_name,
)
else:
raise NotImplementedError

# load data
with open(args.input_path, "r") as fp:
input_data = json.load(fp)

# extract triplets
print('Extracting')
print("Extracting")
output_data = []
for item in tqdm(input_data):
assert "response" in item, "response field is required"
response = item["response"]
question = item.get("question", None)
triplets = extractor.extract_claim_triplets(response, question, max_new_tokens=args.extractor_max_new_tokens)
triplets = extractor.extract_claim_triplets(
response, question, max_new_tokens=args.extractor_max_new_tokens
)
out_item = {**item, **{"triplets": triplets}}
output_data.append(out_item)
with open(args.output_path, "w") as fp:
Expand All @@ -152,17 +188,17 @@ def check(args):
elif args.checker_name == "gpt4":
checker = GPT4Checker()
elif args.checker_name == "nli":
checker = NLIChecker()
checker = NLIChecker(device=args.nli_device)
else:
raise NotImplementedError

retriever = None
if args.use_retrieval:
if args.retriever_name == "google":
retriever = GoogleRetriever(args.cache_dir)
else:
raise NotImplementedError

if args.aggregator_name == "strict":
agg_fn = strict_agg
elif args.aggregator_name == "soft":
Expand All @@ -171,13 +207,13 @@ def check(args):
agg_fn = major_agg
else:
raise NotImplementedError

# load data
with open(args.input_path, "r") as fp:
input_data = json.load(fp)

# check triplets
print('Checking')
print("Checking")
output_data = []
for item in tqdm(input_data):
assert "triplets" in item, "triplets field is required"
Expand All @@ -186,21 +222,19 @@ def check(args):
reference = retriever.retrieve(item["response"])
item["reference"] = reference
else:
assert "reference" in item, \
"reference field is required if retriever is not used."
assert (
"reference" in item
), "reference field is required if retriever is not used."
reference = item["reference"]
question = item.get("question", None)
results = [
checker.check(t, reference, question=question)
for t in triplets
]
results = [checker.check(t, reference, question=question) for t in triplets]
agg_results = agg_fn(results)
out_item = {
**item,
**{
"Y": agg_results,
"ys": results,
}
},
}
output_data.append(out_item)
with open(args.output_path, "w") as fp:
Expand Down
1 change: 1 addition & 0 deletions refchecker/extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .claude2_extractor import Claude2Extractor
from .gpt4_extractor import GPT4Extractor
from .mistral_extractor import MistralExtractor
Loading