diff --git a/.github/workflows/test_learnware_with_pip.yaml b/.github/workflows/test_learnware_with_pip.yaml index c59a63fa..48fccf2c 100644 --- a/.github/workflows/test_learnware_with_pip.yaml +++ b/.github/workflows/test_learnware_with_pip.yaml @@ -13,7 +13,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04] + os: [ubuntu-22.04] python-version: [3.9] steps: diff --git a/.github/workflows/test_learnware_with_source.yaml b/.github/workflows/test_learnware_with_source.yaml index 13b6ac94..0e4bd7ad 100644 --- a/.github/workflows/test_learnware_with_source.yaml +++ b/.github/workflows/test_learnware_with_source.yaml @@ -13,7 +13,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04] + os: [ubuntu-22.04] python-version: [3.9] steps: @@ -50,4 +50,4 @@ jobs: - name: Test workflow run: | - conda run -n learnware python -m pytest tests/test_workflow/test_hetero_workflow.py \ No newline at end of file + conda run -n learnware python -m pytest tests/test_workflow/test_hetero_workflow.py diff --git a/.gitignore b/.gitignore index f361d742..891813c1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,9 @@ dist/ *.pkl *.hd5 *.csv +!/examples/dataset_llm_workflow/model_performance/medical.csv +!/examples/dataset_llm_workflow/model_performance/math.csv +!/examples/dataset_llm_workflow/model_performance/finance.csv *.out *.html *.dot @@ -45,4 +48,5 @@ learnware_pool/ PFS/ data/ examples/results/ -examples/*/results/ \ No newline at end of file +examples/*/results/ +examples/*/user_specs/ \ No newline at end of file diff --git a/CHANGES.rst b/CHANGES.rst index ae692d3c..afd27c87 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,16 @@ Changelog ========= -Here you can see the full list of changes between ``learnware`` release. +Here you can see the full list of changes between ``learnware`` releases. -Version 0.3.2 +Version 0.4.0.post1 (2025-05-25) --------------- -This is the first public release of ``learnware`` package. +* Bugfix release. + +Version 0.4.0 (2025-05-20) +--------------- +* Added support for 7B level language model learnwares. +* Added two new specifications, specifically designed for language model learnwares. + +Version 0.3.2 (2024-01-24) +--------------- +* First public release of ``learnware`` package. \ No newline at end of file diff --git a/README.md b/README.md index e1ac07d1..e9287334 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,51 @@ The results are depicted in the following table and figure. Similarly, even when +# LLM Experimental Results (New) + +This section refers to Section 4 of our paper [*Learnware of Language Models: Specialized Small Language Models Can Do Big*](https://arxiv.org/abs/2505.13425). We simulate a learnware system comprising approximately 100 learnwares of specialized SLMs with 8B parameters, fine-tuned across finance, healthcare, and mathematics domains. + +Experimental results demonstrate promising performance: by selecting one suitable learnware for each task-specific inference, the system outperforms the base SLMs on all benchmarks. Compared to LLMs, the system outperforms Qwen1.5-110B, Qwen2.5-72B, and Llama3.1-70B-Instruct by at least 14% in finance domain tasks. Additionally, it surpasses Flan-PaLM-540B (ranked 7th on the [Open Medical LLM Leaderboard](https://huggingface.co/spaces/openlifescienceai/open_medical_llm_leaderboard)) in medical domain tasks. + +The figure and table below show the performance value in finance scenario. + +
+ +
+ +
+ +| User | Qwen2.5-7B | Llama3.1-8B-Instruct | Llama3.1-8B | Qwen1.5-110B | Qwen2.5-72B | Llama3.1-70B-Instruct | Random | Learnware | Best-single | Oracle | +|:-------------------------|:-------------|:-----------------------|:--------------|:---------------|:--------------|:------------------------|:---------|:------------|:--------------|:---------| +| australian | 43.17 | 44.6 | 43.17 | 43.17 | 43.17 | 47.48 | 44.45 | 56.83 | 42.21 | 56.83 | +| cra_lendingclub | 80.82 | 76.33 | 57.34 | 80.82 | 47.01 | 53.07 | 81.52 | 92.07 | 80.82 | 92.07 | +| fiqasa | 38.3 | 40.43 | 56.17 | 63.4 | 64.26 | 68.51 | 46.53 | 76.38 | 32.06 | 76.38 | +| fpb | 76.08 | 32.78 | 30.72 | 70.72 | 78.35 | 78.04 | 67.95 | 84.25 | 77.73 | 84.25 | +| german | 65.0 | 49.5 | 66.0 | 66.0 | 66.5 | 43.5 | 51.5 | 67.06 | 65.33 | 67.06 | +| headlines | 74.81 | 59.95 | 59.95 | 62.96 | 77.84 | 77.53 | 72.43 | 95.61 | 95.61 | 95.61 | +| ner | 21.75 | 0.62 | 9.01 | 17.89 | 9.36 | 9.52 | 24.99 | 52.79 | 23.98 | 52.79 | +| sm_acl | 51.1 | 51.4 | 51.34 | 49.3 | 51.56 | 49.38 | 51.42 | 52.82 | 50.71 | 53.63 | +| sm_bigdata | 55.3 | 55.57 | 52.79 | 51.02 | 50.27 | 47.76 | 53.86 | 52.4 | 55.52 | 55.88 | +| sm_cikm | 58.44 | 54.24 | 54.07 | 44.01 | 58.27 | 47.86 | 55.89 | 55.99 | 57.98 | 58.52 | +| causal20_sc | 65.14 | 88.48 | 79.45 | 83.75 | 76.17 | 87.16 | 74.71 | 84.17 | 88.61 | 88.61 | +| finarg_ecc_arc | 64.78 | 46.67 | 60.0 | 62.32 | 63.04 | 44.64 | 62.27 | 64.31 | 57.87 | 68.36 | +| finarg_ecc_auc | 48.3 | 51.81 | 49.85 | 55.01 | 61.71 | 65.02 | 52.08 | 58.08 | 48.68 | 58.08 | +| fomc | 60.48 | 29.44 | 34.68 | 58.47 | 57.66 | 66.13 | 56.05 | 62.7 | 61.36 | 62.7 | +| ma | 79.2 | 56.4 | 51.0 | 81.4 | 84.6 | 83.2 | 73.64 | 79.81 | 79.27 | 79.81 | +| mlesg | 35.67 | 32.67 | 20.0 | 34.67 | 38.67 | 42.33 | 31.99 | 33.42 | 38.33 | 38.33 | +| multifin_en | 60.99 | 31.32 | 28.39 | 65.38 | 63.55 | 68.5 | 54.96 | 63.46 | 58.61 | 63.46 | +| Avg. | 57.61 | 47.19 | 47.29 | 58.25 | 58.35 | 57.63 | 56.25 | 66.6 | 59.69 | 67.79 | +| Avg. rank | 5.94 | 7.35 | 7.82 | 5.94 | 4.71 | 5.24 | 6.47 | 2.88 | 5.47 | 1.65 | +| Learnware (win/tie/loss) | 13/0/4 | 15/0/2 | 16/0/1 | 14/0/3 | 12/0/5 | 11/0/6 | 16/0/1 | nan | 12/1/4 | 0/11/6 | +| Oracle (win/tie/loss) | 17/0/0 | 17/0/0 | 17/0/0 | 15/0/2 | 13/0/4 | 12/0/5 | 17/0/0 | 6/11/0 | 14/3/0 | nan | + +
+
+Our system demonstrates strong performance across financial tasks, achieving the highest average score among all methods and delivering a nearly 14% improvement over the best large-scale model, Qwen2.5-72B. Among strategies that utilize specialized SLMs (excluding Oracle), it ranks first in 13 out of 17 tasks, identifies the optimal learnware (tying with Oracle) on 11, and outperforms all contenders on 8.
+
+These results show that our system can match or surpass large-scale models with over 70B parameters under the Task-Level evaluation setting, while requiring only the memory footprint of sub-8B models.
+
+**For more scenarios (medical and math) and details, please see [here](./examples/dataset_llm_workflow/README.md).**

 # Citation

diff --git a/README_zh.md b/README_zh.md
index 75bf0bf5..f413f081 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -398,6 +398,51 @@ feature_augment_predict_y = reuse_feature_augment.predict(user_data=test_x)
+
+
+
+# LLM 实验结果（新增）
+
+本节对应于我们的论文 [*Learnware of Language Models: Specialized Small Language Models Can Do Big*](https://arxiv.org/abs/2505.13425) 的第 4 部分。我们模拟建立了一个含有约 100 个 8B 级别专用 SLM 学件的学件基座系统，涵盖金融、医疗和数学三个领域。
+
+实验结果展现了我们系统的良好性能：通过为每个专用领域任务选择一个合适的学件，该系统在所有场景的基准测试中均优于基座 SLM 以及基线算法；与 70B 以上的大参数规模语言模型相比，该系统在大幅减少显存占用的情况下，在金融领域中的性能表现至少比 Qwen1.5-110B、Qwen2.5-72B 和 Llama3.1-70B-Instruct 高出 14%。此外，在医疗领域中，它超越了 Flan-PaLM-540B（在 [Open Medical LLM Leaderboard](https://huggingface.co/spaces/openlifescienceai/open_medical_llm_leaderboard) 上排名第七）。
+
+下图和表格展示了不同方法或模型在金融评估场景上的性能分数：
+
+ +
+ +
+ +| User | Qwen2.5-7B | Llama3.1-8B-Instruct | Llama3.1-8B | Qwen1.5-110B | Qwen2.5-72B | Llama3.1-70B-Instruct | Random | Learnware | Best-single | Oracle | +|:-------------------------|:-------------|:-----------------------|:--------------|:---------------|:--------------|:------------------------|:---------|:------------|:--------------|:---------| +| australian | 43.17 | 44.6 | 43.17 | 43.17 | 43.17 | 47.48 | 44.45 | 56.83 | 42.21 | 56.83 | +| cra_lendingclub | 80.82 | 76.33 | 57.34 | 80.82 | 47.01 | 53.07 | 81.52 | 92.07 | 80.82 | 92.07 | +| fiqasa | 38.3 | 40.43 | 56.17 | 63.4 | 64.26 | 68.51 | 46.53 | 76.38 | 32.06 | 76.38 | +| fpb | 76.08 | 32.78 | 30.72 | 70.72 | 78.35 | 78.04 | 67.95 | 84.25 | 77.73 | 84.25 | +| german | 65.0 | 49.5 | 66.0 | 66.0 | 66.5 | 43.5 | 51.5 | 67.06 | 65.33 | 67.06 | +| headlines | 74.81 | 59.95 | 59.95 | 62.96 | 77.84 | 77.53 | 72.43 | 95.61 | 95.61 | 95.61 | +| ner | 21.75 | 0.62 | 9.01 | 17.89 | 9.36 | 9.52 | 24.99 | 52.79 | 23.98 | 52.79 | +| sm_acl | 51.1 | 51.4 | 51.34 | 49.3 | 51.56 | 49.38 | 51.42 | 52.82 | 50.71 | 53.63 | +| sm_bigdata | 55.3 | 55.57 | 52.79 | 51.02 | 50.27 | 47.76 | 53.86 | 52.4 | 55.52 | 55.88 | +| sm_cikm | 58.44 | 54.24 | 54.07 | 44.01 | 58.27 | 47.86 | 55.89 | 55.99 | 57.98 | 58.52 | +| causal20_sc | 65.14 | 88.48 | 79.45 | 83.75 | 76.17 | 87.16 | 74.71 | 84.17 | 88.61 | 88.61 | +| finarg_ecc_arc | 64.78 | 46.67 | 60.0 | 62.32 | 63.04 | 44.64 | 62.27 | 64.31 | 57.87 | 68.36 | +| finarg_ecc_auc | 48.3 | 51.81 | 49.85 | 55.01 | 61.71 | 65.02 | 52.08 | 58.08 | 48.68 | 58.08 | +| fomc | 60.48 | 29.44 | 34.68 | 58.47 | 57.66 | 66.13 | 56.05 | 62.7 | 61.36 | 62.7 | +| ma | 79.2 | 56.4 | 51.0 | 81.4 | 84.6 | 83.2 | 73.64 | 79.81 | 79.27 | 79.81 | +| mlesg | 35.67 | 32.67 | 20.0 | 34.67 | 38.67 | 42.33 | 31.99 | 33.42 | 38.33 | 38.33 | +| multifin_en | 60.99 | 31.32 | 28.39 | 65.38 | 63.55 | 68.5 | 54.96 | 63.46 | 58.61 | 63.46 | +| Avg. | 57.61 | 47.19 | 47.29 | 58.25 | 58.35 | 57.63 | 56.25 | 66.6 | 59.69 | 67.79 | +| Avg. rank | 5.94 | 7.35 | 7.82 | 5.94 | 4.71 | 5.24 | 6.47 | 2.88 | 5.47 | 1.65 | +| Learnware (win/tie/loss) | 13/0/4 | 15/0/2 | 16/0/1 | 14/0/3 | 12/0/5 | 11/0/6 | 16/0/1 | nan | 12/1/4 | 0/11/6 | +| Oracle (win/tie/loss) | 17/0/0 | 17/0/0 | 17/0/0 | 15/0/2 | 13/0/4 | 12/0/5 | 17/0/0 | 6/11/0 | 14/3/0 | nan | + +
+
+我们的系统在金融任务中表现出色，在所有方法中取得了最高的平均得分，比表现最好的大参数规模模型 Qwen2.5-72B 性能提高了 14%。在 17 个任务中，有 13 个任务的得分高于除 Oracle 外的专用 SLM 模型选择方法，在 11 个任务上查搜到了最优学件（性能表现与 Oracle 一致），在 8 个任务上战胜了所有其他方法或模型。
+
+上述结果表明，在任务级评估的实验设定下，仅查搜使用参数规模在 8B 级别的小型语言模型，学件基座系统的整体表现可以媲美甚至超越参数规模在 70B 以上的大模型，并大幅降低模型推理时的显存占用。
+
+**更多场景（医学和数学）上的实验结果和详细信息，请参阅[此处](./examples/dataset_llm_workflow/README.md)。**

 # 引用

diff --git a/docs/_static/img/llm-finance.svg b/docs/_static/img/llm-finance.svg
new file mode 100644
index 00000000..53c53f28
--- /dev/null
+++ b/docs/_static/img/llm-finance.svg
@@ -0,0 +1,3650 @@
[3,650 lines of SVG markup omitted: Matplotlib v3.9.2 figure of the finance-scenario results, generated 2025-05-24]
diff --git a/docs/_static/img/llm-math.svg b/docs/_static/img/llm-math.svg
new file mode 100644
index 00000000..17e18311
--- /dev/null
+++ b/docs/_static/img/llm-math.svg
@@ -0,0 +1,3643 @@
[3,643 lines of SVG markup omitted: Matplotlib v3.9.2 figure of the math-scenario results, generated 2025-05-24]
diff --git a/docs/_static/img/llm-medical.svg b/docs/_static/img/llm-medical.svg
new file mode 100644
index 00000000..44a274c8
--- /dev/null
+++ b/docs/_static/img/llm-medical.svg
@@ -0,0 +1,2721 @@
[2,721 lines of SVG markup omitted: Matplotlib v3.9.2 figure of the medical-scenario results, generated 2025-05-24]
diff --git
a/docs/components/market.rst b/docs/components/market.rst index 896b5ec0..3c349b98 100644 --- a/docs/components/market.rst +++ b/docs/components/market.rst @@ -65,7 +65,7 @@ Easy market is a basic realization of the learnware market. It consists of ``Eas ``EasyOrganizer`` mainly has the following methods to store learnwares, which is an easy way to organize learnwares. - **reload_market**: Reload the learnware market when the server restarts and return a flag indicating whether the market is reloaded successfully. -- **add_learnware**: Add a learnware with ``learnware_id``, ``semantic_spec`` and model files in ``zip_path`` into the market. Return the ``learnware_id`` and ``learnwere_status``. The ``learnwere_status`` is set to ``check_status`` if it is provided. Otherwise, the ``checker`` will be called to generate the ``learnwere_status``. +- **add_learnware**: Add a learnware with ``learnware_id``, ``semantic_spec`` and model files in ``zip_path`` into the market. Return the ``learnware_id`` and ``learnware_status``. The ``learnware_status`` is set to ``check_status`` if it is provided. Otherwise, the ``checker`` will be called to generate the ``learnware_status``. - **delete_learnware**: Delete the learnware with ``id`` from the market and return a flag indicating whether the deletion is successful. - **update_learnware**: Update the learnware's ``zip_path``, ``semantic_spec``, ``check_status``. If None, the corresponding item is not updated. Return a flag indicating whether it passed the ``checker``. - **get_learnwares**: Similar to **get_learnware_ids**, but return list of learnwares instead of ids. @@ -148,7 +148,7 @@ As more learnwares are submitted, this heterogeneous engine will continuously up - **reload_market**: Reloads the heterogeneous engine if there is one. Otherwise, initialize an engine with default configurations. Returns a flag indicating whether the market is reloaded successfully. - **reset**: Resets the heterogeneous market with specific settings regarding the heterogeneous engine such as ``auto_update``, ``auto_update_limit`` and ``training_args`` configurations. -- **add_learnware**: Add a learnware into the market, meanwhile generating ``HeteroMapTableSpecification`` for the learnware using the heterogeneous engine. The engine's update process will be triggered if ``auto_update`` is set to True and the number of learnwares in the market with ``USABLE_LEARNWARE`` status exceeds ``auto_update_limit``. Return the ``learnware_id`` and ``learnwere_status``. +- **add_learnware**: Add a learnware into the market, meanwhile generating ``HeteroMapTableSpecification`` for the learnware using the heterogeneous engine. The engine's update process will be triggered if ``auto_update`` is set to True and the number of learnwares in the market with ``USABLE_LEARNWARE`` status exceeds ``auto_update_limit``. Return the ``learnware_id`` and ``learnware_status``. - **delete_learnware**: Removes the learnware with ``id`` from the market and also removes its new specification if there is one. Return a flag of whether the deletion is successful. - **update_learnware**: Update the learnware's ``zip_path``, ``semantic_spec``, ``check_status`` and its new specification if there is one. Return a flag indicating whether it passed the ``checker``. - **generate_hetero_map_spec**: Generate ``HeteroMapTableSpecification`` for users based on the user's statistical specification provided in ``user_info``. 
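
To make the organizer interface described above more concrete, the following is a minimal sketch of a submit-then-search round trip. It only uses names that appear elsewhere in this repository (`instantiate_learnware_market`, `BaseUserInfo`, `generate_stat_spec`); the `name="hetero"` argument, the `add_learnware`/`search_learnware` signatures, and the `stat_spec.type` attribute are assumptions and may differ from the actual API.

```python
# Hedged sketch only -- method names follow the documentation above, signatures are assumed.
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.specification import generate_stat_spec


def market_round_trip(zip_path, semantic_spec, user_X, user_semantic):
    # Instantiate a market backed by the heterogeneous organizer ("hetero" is an assumed name).
    market = instantiate_learnware_market(market_id="demo", name="hetero", rebuild=True)

    # Submit a learnware: the market stores the zipped model plus its semantic specification
    # and returns the assigned id together with its check status.
    learnware_id, learnware_status = market.add_learnware(zip_path, semantic_spec)

    # On the user side, build a statistical specification from raw data and combine it with
    # the semantic requirement into a BaseUserInfo object.
    stat_spec = generate_stat_spec(type="table", X=user_X)
    user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={stat_spec.type: stat_spec})

    # Query the searcher for learnwares matching the user's requirement.
    return learnware_id, learnware_status, market.search_learnware(user_info)
```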
diff --git a/examples/dataset_image_workflow/workflow.py b/examples/dataset_image_workflow/workflow.py
index 685a47fa..2cb29ef1 100644
--- a/examples/dataset_image_workflow/workflow.py
+++ b/examples/dataset_image_workflow/workflow.py
@@ -18,7 +18,7 @@
 from learnware.market import BaseUserInfo, instantiate_learnware_market
 from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
 from learnware.specification import generate_stat_spec
-from learnware.tests.benchmarks import LearnwareBenchmark
+from learnware.tests.benchmarks import LearnwareBenchmarkManager
 from learnware.utils import choose_device
 
 logger = get_module_logger("image_workflow", level="INFO")
@@ -57,7 +57,7 @@ def _plot_labeled_peformance_curves(self, all_user_curves_data):
 
     def _prepare_market(self, rebuild=False):
         client = LearnwareClient()
-        self.image_benchmark = LearnwareBenchmark().get_benchmark(image_benchmark_config)
+        self.image_benchmark = LearnwareBenchmarkManager().get_benchmark(image_benchmark_config)
         self.image_market = instantiate_learnware_market(market_id=self.image_benchmark.name, rebuild=rebuild)
         self.user_semantic = client.get_semantic_specification(self.image_benchmark.learnware_ids[0])
         self.user_semantic["Name"]["Values"] = ""
diff --git a/examples/dataset_llm_workflow/README.md b/examples/dataset_llm_workflow/README.md
new file mode 100644
index 00000000..dd86e694
--- /dev/null
+++ b/examples/dataset_llm_workflow/README.md
@@ -0,0 +1,151 @@
+# LLM Dataset Workflow Example
+
+## Introduction
+
+This workflow refers to Section 4 of our paper [*Learnware of Language Models: Specialized Small Language Models Can Do Big*](https://arxiv.org/abs/2505.13425). We simulate a learnware system comprising approximately 100 learnwares of specialized SLMs with 8B parameters, fine-tuned across the finance, healthcare, and mathematics domains.
+
+We first train multiple models under different configurations by SFT with LoRA on different datasets; Qwen2.5-7B, Llama3.1-8B, and Llama3.1-8B-Instruct are our base models. We then generate specifications for each model and apply an identification algorithm to select the most suitable learnware based on the user's task requirements. The identified learnware is then evaluated on the corresponding task under the **Task-Level** evaluation setting using EleutherAI's [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
+
+We compare Learnware against several contenders, including:
+- Ways to utilize specialized SLM(s): a baseline algorithm, Random learnware selection, and two oracle-style strategies with access to the full evaluation results of all candidate models, Best-single and Oracle. Best-single refers to the model with the highest average score among the learnware candidates, and Oracle denotes the optimal performance attainable by choosing one candidate SLM learnware per task, i.e., selecting the best-performing model on each user task.
+- Base models used for fine-tuning.
+- Well-known large language models (LLMs) with over 70B parameters.
+
+We do not distinguish between different models fine-tuned with the same instruction dataset, so if our method selects a learnware for a given task, the reported performance is the average over all models fine-tuned on the selected instruction dataset.
+
+## Run the code
+
+Since evaluating LLMs is a time-consuming process, we provide the evaluation results of all models in a table so that you can quickly obtain the final system performance.
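
To make the averaging rule above concrete, here is a minimal sketch of how a task's Learnware score could be derived from such a performance table. The CSV name mirrors the files shipped under `model_performance/`, but the column names (`task`, `instruction_dataset`, `score`) are assumptions, not the actual file layout.

```python
# Minimal sketch (assumed schema): the score credited to the Learnware method on a task
# is the mean over all models fine-tuned on the instruction dataset picked by identification.
import pandas as pd


def learnware_score(perf: pd.DataFrame, task: str, selected_dataset: str) -> float:
    # Columns "task", "instruction_dataset" and "score" are hypothetical placeholders.
    rows = perf[(perf["task"] == task) & (perf["instruction_dataset"] == selected_dataset)]
    return rows["score"].mean()


# Example usage:
# perf = pd.read_csv("model_performance/finance.csv")
# print(learnware_score(perf, task="fiqasa", selected_dataset="fiqasa"))
```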
+
+Run the following commands to get results from the performance table of all models in the medical/math/finance scenarios (evaluation is skipped). **We recommend running these.**
+
+```bash
+python workflow.py llm_example medical
+python workflow.py llm_example math
+python workflow.py llm_example finance
+```
+
+Run the following commands to obtain results for the medical, math, and finance scenarios including evaluation. In the medical scenario, it takes 3-4 hours to get the final results on one A100 GPU. For the math and finance scenarios, the process is significantly more time-consuming and requires at least four A100 GPUs.
+
+```bash
+python workflow.py llm_example medical --skip_eval False
+python workflow.py llm_example math --skip_eval False
+python workflow.py llm_example finance --skip_eval False
+```
+
+Following [FinBen](https://github.com/The-FinAI/PIXIU), for evaluation in the finance scenario you first need to copy the folder ```extra_tasks/flare``` into the ```tasks``` directory within the installation path of ```lm_eval```. For example, run the following command:
+
+```bash
+cp -r extra_tasks/flare ~/anaconda3/envs/{env_name}/lib/python3.11/site-packages/lm_eval/tasks/
+```
+
+## Results
+
+### Finance
+
+The figure and table below show the performance of different methods and language models in the finance scenario.
+
+ +
+ +
+ +| User | Qwen2.5-7B | Llama3.1-8B-Instruct | Llama3.1-8B | Qwen1.5-110B | Qwen2.5-72B | Llama3.1-70B-Instruct | Random | Learnware | Best-single | Oracle | +|:-------------------------|:-------------|:-----------------------|:--------------|:---------------|:--------------|:------------------------|:---------|:------------|:--------------|:---------| +| australian | 43.17 | 44.6 | 43.17 | 43.17 | 43.17 | 47.48 | 44.45 | 56.83 | 42.21 | 56.83 | +| cra_lendingclub | 80.82 | 76.33 | 57.34 | 80.82 | 47.01 | 53.07 | 81.52 | 92.07 | 80.82 | 92.07 | +| fiqasa | 38.3 | 40.43 | 56.17 | 63.4 | 64.26 | 68.51 | 46.53 | 76.38 | 32.06 | 76.38 | +| fpb | 76.08 | 32.78 | 30.72 | 70.72 | 78.35 | 78.04 | 67.95 | 84.25 | 77.73 | 84.25 | +| german | 65.0 | 49.5 | 66.0 | 66.0 | 66.5 | 43.5 | 51.5 | 67.06 | 65.33 | 67.06 | +| headlines | 74.81 | 59.95 | 59.95 | 62.96 | 77.84 | 77.53 | 72.43 | 95.61 | 95.61 | 95.61 | +| ner | 21.75 | 0.62 | 9.01 | 17.89 | 9.36 | 9.52 | 24.99 | 52.79 | 23.98 | 52.79 | +| sm_acl | 51.1 | 51.4 | 51.34 | 49.3 | 51.56 | 49.38 | 51.42 | 52.82 | 50.71 | 53.63 | +| sm_bigdata | 55.3 | 55.57 | 52.79 | 51.02 | 50.27 | 47.76 | 53.86 | 52.4 | 55.52 | 55.88 | +| sm_cikm | 58.44 | 54.24 | 54.07 | 44.01 | 58.27 | 47.86 | 55.89 | 55.99 | 57.98 | 58.52 | +| causal20_sc | 65.14 | 88.48 | 79.45 | 83.75 | 76.17 | 87.16 | 74.71 | 84.17 | 88.61 | 88.61 | +| finarg_ecc_arc | 64.78 | 46.67 | 60.0 | 62.32 | 63.04 | 44.64 | 62.27 | 64.31 | 57.87 | 68.36 | +| finarg_ecc_auc | 48.3 | 51.81 | 49.85 | 55.01 | 61.71 | 65.02 | 52.08 | 58.08 | 48.68 | 58.08 | +| fomc | 60.48 | 29.44 | 34.68 | 58.47 | 57.66 | 66.13 | 56.05 | 62.7 | 61.36 | 62.7 | +| ma | 79.2 | 56.4 | 51.0 | 81.4 | 84.6 | 83.2 | 73.64 | 79.81 | 79.27 | 79.81 | +| mlesg | 35.67 | 32.67 | 20.0 | 34.67 | 38.67 | 42.33 | 31.99 | 33.42 | 38.33 | 38.33 | +| multifin_en | 60.99 | 31.32 | 28.39 | 65.38 | 63.55 | 68.5 | 54.96 | 63.46 | 58.61 | 63.46 | +| Avg. | 57.61 | 47.19 | 47.29 | 58.25 | 58.35 | 57.63 | 56.25 | 66.6 | 59.69 | 67.79 | +| Avg. rank | 5.94 | 7.35 | 7.82 | 5.94 | 4.71 | 5.24 | 6.47 | 2.88 | 5.47 | 1.65 | +| Learnware (win/tie/loss) | 13/0/4 | 15/0/2 | 16/0/1 | 14/0/3 | 12/0/5 | 11/0/6 | 16/0/1 | nan | 12/1/4 | 0/11/6 | +| Oracle (win/tie/loss) | 17/0/0 | 17/0/0 | 17/0/0 | 15/0/2 | 13/0/4 | 12/0/5 | 17/0/0 | 6/11/0 | 14/3/0 | nan | + +
+
+Our system demonstrates strong performance across financial tasks, achieving the highest average score among all methods and delivering a nearly 14% improvement over the best large-scale model, Qwen2.5-72B. Among strategies that utilize specialized SLMs (excluding Oracle), it ranks first in 13 out of 17 tasks, identifies the optimal learnware (tying with Oracle) on 11, and outperforms all contenders on 8.
+
+These results show that our system can match or surpass large-scale models with over 70B parameters under the Task-Level evaluation setting, while requiring only the memory footprint of sub-8B models.
+
+### Medical
+
+The figure and table below show the performance of different methods and language models in the medical scenario.
+
+ +
+ +
+ +| User | Qwen2.5-7B | Flan-PaLM-540B | Random | Learnware | Best-single | Oracle | +|:-------------------------|:-------------|:-----------------|:---------|:------------|:--------------|:---------| +| medmcqa | 59.93 | 57.6 | 60.2 | 62.49 | 62.49 | 62.49 | +| medqa_4options | 64.18 | 67.6 | 63.74 | 65.59 | 64.81 | 65.59 | +| anatomy | 71.85 | 63.7 | 71.33 | 71.85 | 70.37 | 72.96 | +| clinical_knowledge | 77.36 | 80.4 | 78.21 | 78.87 | 78.49 | 79.25 | +| college_biology | 82.64 | 88.9 | 84.34 | 85.42 | 84.03 | 86.11 | +| college_medicine | 69.36 | 76.3 | 69.02 | 69.36 | 68.79 | 69.94 | +| medical_genetics | 87.0 | 75.0 | 86.95 | 87.0 | 89.0 | 89.0 | +| professional_medicine | 78.68 | 83.8 | 77.37 | 79.78 | 78.68 | 79.78 | +| pubmedqa | 75.2 | 79.0 | 75.67 | 75.8 | 76.8 | 76.8 | +| Avg. | 74.02 | 74.7 | 74.09 | 75.13 | 74.83 | 75.77 | +| Avg. rank | 4.44 | 2.67 | 4.89 | 2.56 | 3.56 | 1.67 | +| Learnware (win/tie/loss) | 6/3/0 | 3/0/6 | 9/0/0 | nan | 6/1/2 | 0/3/6 | +| Oracle (win/tie/loss) | 9/0/0 | 3/0/6 | 9/0/0 | 6/3/0 | 6/3/0 | nan | + +
+
+As shown, our system achieves the highest average score across the 9 tasks, even surpassing the large-scale model Flan-PaLM-540B. This demonstrates that by leveraging multiple models with fewer than 8B parameters, our system can outperform a single large-scale model in task-specific scenarios. Among SLM utilization strategies, Learnware performs best in 7 out of 9 tasks, tying with Oracle on 6.
+
+Furthermore, the fact that our system surpasses Best-single highlights that its effectiveness comes not from a single exceptionally strong model, but from its specification design, identification mechanism, and the collective strength of all candidate models.
+
+### Math
+
+The figure and table below show the performance of different methods and language models in the math scenario.
+
+ +
+ +
+ +| User | Qwen2.5-7B | Qwen1.5-110B | Random | Learnware | Best-single | Oracle | +|:------------------------------|:-------------|:---------------|:---------|:------------|:--------------|:---------| +| agieval_aqua_rat | 41.73 | 38.98 | 40.09 | 38.98 | 41.33 | 41.73 | +| agieval_gaokao_mathcloze | 16.95 | 38.14 | 11.72 | 17.8 | 13.14 | 17.8 | +| agieval_gaokao_mathqa | 49.86 | 77.78 | 50.35 | 51.57 | 51.0 | 53.42 | +| agieval_math | 19.8 | 19.3 | 20.15 | 20.6 | 18.5 | 28.4 | +| agieval_sat_math | 55.91 | 57.27 | 55.3 | 57.27 | 57.5 | 57.5 | +| cmmlu_college_mathematics | 45.71 | 47.62 | 49.36 | 52.38 | 48.58 | 52.38 | +| cmmlu_elementary_mathematics | 65.65 | 77.83 | 64.49 | 66.96 | 65.0 | 67.18 | +| cmmlu_high_school_mathematics | 61.59 | 77.44 | 62.5 | 60.98 | 64.32 | 64.63 | +| gsm8k | 84.08 | 84.91 | 80.79 | 84.15 | 83.92 | 84.15 | +| mathqa | 43.32 | 48.07 | 41.51 | 41.41 | 46.28 | 46.28 | +| mgsm_native_cot_zh | 66.4 | 68.8 | 67.64 | 73.6 | 68.8 | 73.6 | +| minerva_math | 40.16 | 47.9 | 37.4 | 36.48 | 41.23 | 45.12 | +| abstract_algebra | 54.0 | 53.0 | 53.83 | 56.0 | 52.0 | 56.0 | +| college_mathematics | 53.0 | 52.0 | 53.61 | 53.0 | 53.5 | 58.0 | +| elementary_mathematics | 72.75 | 78.84 | 73.63 | 75.13 | 73.02 | 75.13 | +| high_school_mathematics | 55.93 | 60.0 | 55.21 | 55.56 | 55.19 | 56.86 | +| Avg. | 51.68 | 57.99 | 51.1 | 52.62 | 52.08 | 54.89 | +| Avg. rank | 4.31 | 2.56 | 4.56 | 3.19 | 4.0 | 1.56 | +| Learnware (win/tie/loss) | 10/1/5 | 5/2/9 | 11/0/5 | nan | 10/0/6 | 0/6/10 | +| Oracle (win/tie/loss) | 15/1/0 | 7/0/9 | 16/0/0 | 10/6/0 | 14/2/0 | nan | + +
+ +Our system achieves optimal identification performance (tied with Oracle) in 10 out of 16 tasks and even outperforms all other contenders in 5. However, the large-scale model achieves the highest average score and even beats Oracle (which denotes the optimal performance using one of our 8B-level models). This is likely due to their strong mathematical reasoning abilities that lack in smaller models, rather than a shortcoming of our method, as evidenced by the minimal difference in the "win/tie/loss" of Learnware and Oracle on Qwen1.5-110B. diff --git a/examples/dataset_llm_workflow/benchmark/__init__.py b/examples/dataset_llm_workflow/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/dataset_llm_workflow/benchmark/base.py b/examples/dataset_llm_workflow/benchmark/base.py new file mode 100644 index 00000000..d6a8f8ed --- /dev/null +++ b/examples/dataset_llm_workflow/benchmark/base.py @@ -0,0 +1,72 @@ +from .config import ( + LEARNWARE_MATH, + LEARNWARE_MED, + USER_MED, + USER_MATH, + LEARNWARE_FIN, + USER_FIN, + LEARNWARE_MED_IDS, + LEARNWARE_MATH_IDS, + LEARNWARE_FIN_IDS, +) +from .utils import prepare_train_data, prepare_test_data +from datasets import Dataset +from typing import List, Tuple + + +class Benchmark: + def __init__(self, name: str): + self.name = name + self.set_datasets(name) + + def get_benchmark_name(self): + return self.name + + def set_datasets(self, name: str): + if name == "medical": + self.learnware_dict = LEARNWARE_MED + self.learnware_ids = LEARNWARE_MED_IDS + self.user_dict = USER_MED + elif name == "math": + self.learnware_dict = LEARNWARE_MATH + self.learnware_ids = LEARNWARE_MATH_IDS + self.user_dict = USER_MATH + elif name == "finance": + self.learnware_dict = LEARNWARE_FIN + self.learnware_ids = LEARNWARE_FIN_IDS + self.user_dict = USER_FIN + else: + raise NotImplementedError("other benchmarks are not implemented") + + def get_learnware_ids(self) -> List[str]: + return self.learnware_ids + + def get_learnware_data(self, dataset_name) -> List[str]: + train_dataset, val_dataset = prepare_train_data(self.learnware_dict[dataset_name]) + train_data, val_data = train_dataset["text"], val_dataset["text"] + return train_data, val_data + + def get_learnware_dataset(self, dataset_name) -> Tuple[Dataset, Dataset]: + train_dataset, val_dataset = prepare_train_data(self.learnware_dict[dataset_name]) + return train_dataset, val_dataset + + def get_user_data(self, dataset_name) -> List[str]: + test_dataset = prepare_test_data(self.user_dict[dataset_name]) + test_data = test_dataset["text"] + return test_data + + def get_user_dataset(self, dataset_name) -> Dataset: + test_dataset = prepare_test_data(self.user_dict[dataset_name]) + return test_dataset + + def get_learnwares(self): + return self.learnware_dict + + def get_users(self): + return self.user_dict + + def get_learnware_names(self) -> List[str]: + return list(self.learnware_dict.keys()) + + def get_user_names(self) -> List[str]: + return list(self.user_dict.keys()) diff --git a/examples/dataset_llm_workflow/benchmark/config.py b/examples/dataset_llm_workflow/benchmark/config.py new file mode 100644 index 00000000..3f4aa0b1 --- /dev/null +++ b/examples/dataset_llm_workflow/benchmark/config.py @@ -0,0 +1,199 @@ +LEARNWARE_MATH = { + "MWP-Instruct": "Macropodus/MWP-Instruct", + "school_math_0.25M": "BelleGroup/school_math_0.25M", + "MathInstruct": "TIGER-Lab/MathInstruct", + "MetaMathQA": "meta-math/MetaMathQA", + "orca-math-word-problems-200k": 
"microsoft/orca-math-word-problems-200k", + "Arithmo-Data": "akjindal53244/Arithmo-Data", + "MATH_train": "ScalableMath/MATH_train-cleaned_processed", + "MetaMath-GSM240K": "fxmeng/MetaMath-GSM240K", + "GSM8K_zh": "meta-math/GSM8K_zh", +} + +LEARNWARE_MED = { + "AlpaCare": "lavita/AlpaCare-MedInstruct-52k", + "ChatDoctor": "lavita/ChatDoctor-HealthCareMagic-100k", + "medalpaca_cleaned": "medalpaca/medical_meadow_wikidoc,medalpaca/medical_meadow_medical_flashcards,medalpaca/medical_meadow_wikidoc_patient_information,medalpaca/medical_meadow_pubmed_causal,medalpaca/medical_meadow_mediqa,medalpaca/medical_meadow_health_advice", + "medqa_train": "medalpaca/medical_meadow_medqa", + "pubmed_causal": "medalpaca/medical_meadow_pubmed_causal", + "medmcqa_train": "chenhaodev/medmcqa_instruct", + "medqa_train&pubmed_causal": "medalpaca/medical_meadow_medqa,medalpaca/medical_meadow_pubmed_causal", + "AlpaCare&ChatDoctor": "LinhDuong/chatdoctor-5k,lavita/ChatDoctor-HealthCareMagic-100k,lavita/AlpaCare-MedInstruct-52k", + "medalpaca_cleaned&AlpaCare&ChatDoctor": "medalpaca/medical_meadow_wikidoc,medalpaca/medical_meadow_medical_flashcards,medalpaca/medical_meadow_wikidoc_patient_information,medalpaca/medical_meadow_pubmed_causal,medalpaca/medical_meadow_mediqa,medalpaca/medical_meadow_health_advice,LinhDuong/chatdoctor-5k,lavita/ChatDoctor-HealthCareMagic-100k,lavita/AlpaCare-MedInstruct-52k", + "medqa_train&medmcqa_train": "medalpaca/medical_meadow_medqa,chenhaodev/medmcqa_instruct", +} + +LEARNWARE_FIN = { + "australian": "ChanceFocus/flare-australian", + "cra_lendingclub": "ChanceFocus/cra-lendingclub", + "fiqasa": "ChanceFocus/flare-fiqasa", + "fpb": "ChanceFocus/en-fpb", + "german": "ChanceFocus/flare-german", + "headlines": "ChanceFocus/flare-headlines", + "ner": "ChanceFocus/flare-ner", + "sm_acl": "ChanceFocus/flare-sm-acl", + "sm_bigdata": "TheFinAI/en-forecasting-bigdata", + "sm_cikm": "ChanceFocus/flare-sm-cikm", +} + +USER_MED = { + "medmcqa": "openlifescienceai/medmcqa", + "medqa_4options": "GBaker/MedQA-USMLE-4-options-hf", + "anatomy": "hails/mmlu_no_train,anatomy", + "clinical_knowledge": "hails/mmlu_no_train,clinical_knowledge", + "college_biology": "hails/mmlu_no_train,college_biology", + "college_medicine": "hails/mmlu_no_train,college_medicine", + "medical_genetics": "hails/mmlu_no_train,medical_genetics", + "professional_medicine": "hails/mmlu_no_train,professional_medicine", + "pubmedqa": "bigbio/pubmed_qa,pubmed_qa_labeled_fold0_source", +} + +USER_MATH = { + "agieval_aqua_rat": "hails/agieval-aqua-rat", + "agieval_gaokao_mathcloze": "hails/agieval-gaokao-mathcloze", + "agieval_gaokao_mathqa": "hails/agieval-gaokao-mathqa", + "agieval_math": "hails/agieval-math", + "agieval_sat_math": "hails/agieval-sat-math", + "cmmlu_college_mathematics": "haonan-li/cmmlu,college_mathematics", + "cmmlu_elementary_mathematics": "haonan-li/cmmlu,elementary_mathematics", + "cmmlu_high_school_mathematics": "haonan-li/cmmlu,high_school_mathematics", + "gsm8k": "gsm8k,main", + "mathqa": "allenai/math_qa", + "mgsm_native_cot_zh": "juletxara/mgsm,zh", + "minerva_math": "lighteval/MATH,all", + "abstract_algebra": "hails/mmlu_no_train,abstract_algebra", + "college_mathematics": "hails/mmlu_no_train,college_mathematics", + "elementary_mathematics": "hails/mmlu_no_train,elementary_mathematics", + "high_school_mathematics": "hails/mmlu_no_train,high_school_mathematics", +} + +USER_FIN = { + "australian": "ChanceFocus/flare-australian", + "cra_lendingclub": "ChanceFocus/cra-lendingclub", + "fiqasa": 
"ChanceFocus/flare-fiqasa", + "fpb": "ChanceFocus/en-fpb", + "german": "ChanceFocus/flare-german", + "headlines": "ChanceFocus/flare-headlines", + "ner": "ChanceFocus/flare-ner", + "sm_acl": "ChanceFocus/flare-sm-acl", + "sm_bigdata": "TheFinAI/en-forecasting-bigdata", + "sm_cikm": "ChanceFocus/flare-sm-cikm", + "causal20_sc": "ChanceFocus/flare-causal20-sc", + "finarg_ecc_arc": "ChanceFocus/flare-finarg-ecc-arc", + "finarg_ecc_auc": "ChanceFocus/flare-finarg-ecc-auc", + "fomc": "ChanceFocus/flare-fomc", + "ma": "ChanceFocus/flare-ma", + "mlesg": "ChanceFocus/flare-mlesg", + "multifin_en": "ChanceFocus/flare-multifin-en", +} + +LEARNWARE_MED_IDS = [ + "00002789", + "00002790", + "00002791", + "00002792", + "00002793", + "00002794", + "00002795", + "00002796", + "00002797", + "00002798", + "00002799", + "00002800", + "00002801", +] + +LEARNWARE_MATH_IDS = [ + "00002802", + "00002803", + "00002804", + "00002805", + "00002806", + "00002807", + "00002808", + "00002809", + "00002810", + "00002811", + "00002812", + "00002813", + "00002814", +] + + +LEARNWARE_FIN_IDS = [ + "00002815", + "00002816", + "00002817", + "00002818", + "00002819", + "00002820", + "00002821", + "00002822", + "00002823", + "00002824", + "00002825", + "00002826", + "00002827", + "00002828", + "00002829", + "00002830", + "00002831", + "00002832", + "00002833", + "00002834", + "00002835", + "00002836", + "00002837", + "00002838", + "00002839", + "00002840", + "00002841", + "00002842", + "00002843", + "00002844", + "00002845", + "00002846", + "00002847", + "00002848", + "00002849", + "00002850", + "00002851", + "00002852", + "00002853", + "00002854", + "00002855", + "00002856", + "00002857", + "00002858", + "00002859", + "00002860", + "00002861", + "00002862", + "00002863", + "00002864", + "00002865", + "00002866", + "00002867", + "00002868", + "00002869", + "00002870", + "00002871", + "00002872", + "00002873", + "00002874", + "00002875", + "00002876", + "00002877", + "00002878", + "00002879", + "00002880", + "00002881", + "00002882", + "00002883", + "00002884", + "00002885", + "00002886", + "00002887", + "00002888", + "00002889", +] diff --git a/examples/dataset_llm_workflow/benchmark/utils.py b/examples/dataset_llm_workflow/benchmark/utils.py new file mode 100644 index 00000000..2e62d26b --- /dev/null +++ b/examples/dataset_llm_workflow/benchmark/utils.py @@ -0,0 +1,500 @@ +import re +import random +from datasets import load_dataset, concatenate_datasets + +from .config import LEARNWARE_FIN, USER_FIN + + +def preprocess_alpaca(docs): + alpaca_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}" + instructions = docs["instruction"] + inputs = docs["input"] + outputs = docs["output"] + texts = [] + for instruction, input, output in zip(instructions, inputs, outputs): + text = alpaca_prompt.format(instruction, input, output) + texts.append(text) + return texts + + +def preprocess_alpaca_no_label(docs): + alpaca_no_label_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
\n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n" + instructions = docs["instruction"] + inputs = docs["input"] + texts = [] + for instruction, input in zip(instructions, inputs): + text = alpaca_no_label_prompt.format(instruction, input) + texts.append(text) + return texts + + +def preprocess_alpaca_no_input(docs): + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["instruction"] + outputs = docs["output"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_alpaca_no_input_no_label(docs): + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["instruction"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qr(docs): + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["query"] + outputs = docs["response"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qr_no_label(docs): + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["query"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qr_zh(docs): + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["query_zh"] + outputs = docs["response_zh"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qr_zh_no_label(docs): + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["query_zh"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qa(docs): + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["question"] + outputs = docs["answer"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qa_no_label(docs): + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
\n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["question"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qa_zh(docs): + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["question_zh"] + outputs = docs["answer_zh"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qa_zh_no_label(docs) -> str: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["question_zh"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_finance(docs) -> str: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["query"] + outputs = docs["answer"] + texts = [] + for instruction, output in zip(instructions, outputs): + instruction.rstrip(" Answer:") + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_math_train(docs) -> str: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["question"] + outputs = docs["answer_detail"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +# Copied from Master +def preprocess_medmcqa(doc) -> str: + """ + Question: + Choices: + A. + B. + C. + D. + Answer: + """ + choices = [doc["opa"], doc["opb"], doc["opc"], doc["opd"]] + option_choices = { + "A": choices[0], + "B": choices[1], + "C": choices[2], + "D": choices[3], + } + + prompt = "Question: " + doc["question"] + "\nChoices:\n" + for choice, option in option_choices.items(): + prompt += f"{choice.upper()}. {option}\n" + prompt += "Answer:" + return prompt + + +def preprocess_medmcqa_val(docs): + opas = docs["opa"] + opbs = docs["opb"] + opcs = docs["opc"] + opds = docs["opd"] + questions = docs["question"] + option_ids = docs["cop"] + texts = [] + for opa, opb, opc, opd, question, option_id in zip(opas, opbs, opcs, opds, questions, option_ids): + option_choices = { + "A": opa, + "B": opb, + "C": opc, + "D": opd, + } + prompt = "Question: " + question + "\nChoices:\n" + for choice, option in option_choices.items(): + prompt += f"{choice.upper()}. {option}\n" + prompt += f"Answer: {list(option_choices.keys())[option_id]}" + texts.append(prompt) + return texts + + +def preprocess_medqa(doc) -> str: + option_choices = { + "A": doc["ending0"], + "B": doc["ending1"], + "C": doc["ending2"], + "D": doc["ending3"], + } + answers = "".join((f"{k}. 
{v}\n") for k, v in option_choices.items()) + return f"Question: {doc['sent1']}\n{answers}Answer:" + + +def preprocess_medqa_val(docs): + ending0s = docs["ending0"] + ending1s = docs["ending1"] + ending2s = docs["ending2"] + ending3s = docs["ending3"] + sent1s = docs["sent1"] + labels = docs["label"] + texts = [] + for sent1, ending0, ending1, ending2, ending3, label in zip(sent1s, ending0s, ending1s, ending2s, ending3s, labels): + option_choices = { + "A": ending0, + "B": ending1, + "C": ending2, + "D": ending3, + } + answers = "".join((f"{k}. {v}\n") for k, v in option_choices.items()) + texts.append(f"Question: {sent1}\n{answers}Answer: {list(option_choices.keys())[label]}") + return texts + + +def preprocess_mmlu(doc) -> str: + question = doc["question"].strip() + choices = doc["choices"] + return "{}\nA. {}\nB. {}\nC. {}\nD. {}\nAnswer:".format(question, choices[0], choices[1], choices[2], choices[3]) + + +def preprocess_mmlu_val(docs): + questions = docs["question"] + choices = docs["choices"] + answers = docs["answer"] + texts = [] + for question, options, answer in zip(questions, choices, answers): + texts.append( + "{}\nA. {}\nB. {}\nC. {}\nD. {}\nAnswer: {}".format( + question.strip(), options[0], options[1], options[2], options[3], ["A", "B", "C", "D"][answer] + ) + ) + return texts + + +def preprocess_pubmedqa(doc) -> str: + ctxs = "\n".join(doc["CONTEXTS"]) + return "Abstract: {}\nQuestion: {}\nAnswer:".format( + ctxs, + doc["QUESTION"], + ) + + +def preprocess_pubmedqa_val(docs): + contexts_list = docs["CONTEXTS"] + questions = docs["QUESTION"] + answers = docs["final_decision"] + texts = [] + for contexts, question, answer in zip(contexts_list, questions, answers): + ctxs = "\n".join(contexts) + texts.append("Abstract: {}\nQuestion: {}\nAnswer: {}".format(ctxs, question, answer)) + return texts + + +def preprocess_agieval(doc) -> str: + return doc["query"] + + +def preprocess_cmmlu(doc) -> str: + question = doc["Question"].strip() + return "{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:".format(question, doc["A"], doc["B"], doc["C"], doc["D"]) + + +def preprocess_cmmlu_val(docs): + questions = docs["Question"] + as_ = docs["A"] + bs = docs["B"] + cs = docs["C"] + ds = docs["D"] + answers = docs["Answer"] + texts = [] + for question, a, b, c, d, answer in zip(questions, as_, bs, cs, ds, answers): + texts.append("{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:{}".format(question.strip(), a, b, c, d, answer)) + return texts + + +def preprocess_mathqa(doc) -> str: + return "Question: {}\nAnswer:".format(doc["Problem"]) + + +def preprocess_mgsm(doc) -> str: + return "问题: " + doc["question"] + "\n逐步解答:" + + +def preprocess_gsm8k(doc) -> str: + return "Question: {}\nAnswer:".format(doc["question"]) + + +def preprocess_mathqa_val(docs): + problems = docs["Problem"] + corrects = docs["correct"] + options = docs["options"] + texts = [] + for problem, correct, option in zip(problems, corrects, options): + choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", option)] + + # answer = ['a', 'b', 'c', 'd', 'e'].index(correct) + texts.append( + "Question: {}\na. {}\nb. {}\nc. {}\nd. {}\ne. 
{}\nAnswer: {}".format( + problem, choices[0], choices[1], choices[2], choices[3], choices[4], correct + ) + ) + return texts + + +def preprocess_mgsm_val(docs): + questions = docs["question"] + answers = docs["answer"] + texts = [question + "\n" + answer for question, answer in zip(questions, answers)] + return texts + + +def preprocess_gsm8k_val(docs): + instructions = docs["question"] + outputs = docs["answer"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = f"Question: {instruction}\nAnswer: {output}" + texts.append(text) + return texts + + +def preprocess_math(doc: dict) -> str: + return "Problem:" + "\n" + doc["problem"] + "\n\n" + "Solution:" + + +def math_fewshot_prompt(doc: dict) -> str: + return "Problem:" + "\n" + doc["problem"] + "\n\n" + "Solution:" + doc["solution"] + + +def math_fewshot_samples() -> list[dict]: + return [ + { + "problem": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", + "solution": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", + "solution": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", + "solution": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.", + "solution": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. 
I hope it is correct.", + "few_shot": "1", + }, + ] + + +def preprocess_finance_test(doc) -> str: + return doc["query"] + + +PROCESS_FUNC = { + # medical user + "openlifescienceai/medmcqa": preprocess_medmcqa, + "GBaker/MedQA-USMLE-4-options-hf": preprocess_medqa, + "hails/mmlu_no_train": preprocess_mmlu, + "bigbio/pubmed_qa": preprocess_pubmedqa, + # math user + "hails/agieval-gaokao-mathcloze": preprocess_agieval, + "hails/agieval-gaokao-mathqa": preprocess_agieval, + "hails/agieval-aqua-rat": preprocess_agieval, + "hails/agieval-math": preprocess_agieval, + "hails/agieval-sat-math": preprocess_agieval, + "haonan-li/cmmlu": preprocess_cmmlu, + "allenai/math_qa": preprocess_mathqa, + "juletxara/mgsm": preprocess_mgsm, + # "openai/gsm8k": preprocess_gsm8k, + # math learnware + "TIGER-Lab/MathInstruct": preprocess_alpaca_no_input_no_label, + "meta-math/MetaMathQA": preprocess_qr_no_label, + "meta-math/MetaMathQA-40K": preprocess_qr_no_label, + "fxmeng/MetaMath-GSM240K": preprocess_qr_no_label, + "meta-math/MetaMathQA_GSM8K_zh": preprocess_qr_zh_no_label, + "meta-math/GSM8K_zh": preprocess_qa_zh_no_label, + # "Dahoas/MATH-K-100-train": preprocess_math_k_100, + "ScalableMath/MATH_train-cleaned_processed": preprocess_qa_no_label, + "akjindal53244/Arithmo-Data": preprocess_qa_no_label, + "microsoft/orca-math-word-problems-200k": preprocess_qa_no_label, +} + + +PROCESS_FUNC_WITH_LABEL = { + # medical user + "openlifescienceai/medmcqa": preprocess_medmcqa_val, + "GBaker/MedQA-USMLE-4-options-hf": preprocess_medqa_val, + "hails/mmlu_no_train": preprocess_mmlu_val, + "bigbio/pubmed_qa": preprocess_pubmedqa_val, + # math user + "haonan-li/cmmlu": preprocess_cmmlu_val, + "allenai/math_qa": preprocess_mathqa_val, + "juletxara/mgsm": preprocess_mgsm_val, + "lighteval/MATH": preprocess_math_train, + "gsm8k": preprocess_gsm8k_val, + # math learnware + "TIGER-Lab/MathInstruct": preprocess_alpaca_no_input, + "meta-math/MetaMathQA": preprocess_qr, + "meta-math/MetaMathQA-40K": preprocess_qr, + "fxmeng/MetaMath-GSM240K": preprocess_qr, + "meta-math/MetaMathQA_GSM8K_zh": preprocess_qr_zh, + "meta-math/GSM8K_zh": preprocess_qa_zh, + # "Dahoas/MATH-K-100-train": preprocess_math_k_100, + "ScalableMath/MATH_train-cleaned_processed": preprocess_math_train, + "akjindal53244/Arithmo-Data": preprocess_qa, + "microsoft/orca-math-word-problems-200k": preprocess_qa, +} + + +def prepare_train_data(dataset_name_str): + if dataset_name_str in list(PROCESS_FUNC_WITH_LABEL.keys()): + dataset = load_dataset(dataset_name_str, split="train") + if dataset_name_str == "meta-math/GSM8K_zh": + dataset = dataset.filter(lambda x: x["split"] == "train") + dataset = dataset.map(lambda x: {"text": PROCESS_FUNC_WITH_LABEL[dataset_name_str](x)}, batched=True) + split_dataset = dataset.train_test_split(test_size=0.1) + train_dataset = split_dataset["train"] + val_dataset = split_dataset["test"] + elif dataset_name_str in list(LEARNWARE_FIN.values()): + train_dataset = load_dataset(dataset_name_str, split="train") + if "cra" not in dataset_name_str: + val_dataset = load_dataset(dataset_name_str, split="valid") + else: + val_dataset = load_dataset(dataset_name_str, split="validation") + train_dataset = train_dataset.map(lambda x: {"text": preprocess_finance(x)}, batched=True) + val_dataset = val_dataset.map(lambda x: {"text": preprocess_finance(x)}, batched=True) + else: + dataset_list = dataset_name_str.split(",") + train_datasets = [] + for dataset_name in dataset_list: + dataset = load_dataset(dataset_name, split="train") + dataset = 
dataset.remove_columns( + [col for col in dataset.column_names if col not in ["instruction", "input", "output"]] + ) + train_datasets.append(dataset) + combined_dataset = concatenate_datasets(train_datasets) + combined_dataset = combined_dataset.map(lambda x: {"text": preprocess_alpaca(x)}, batched=True) + split_dataset = combined_dataset.train_test_split(test_size=0.1) + train_dataset = split_dataset["train"] + val_dataset = split_dataset["test"] + + return train_dataset, val_dataset + + +def prepare_test_data(dataset_name_str): + temp_list = dataset_name_str.split(",") + subset_name = None + if len(temp_list) != 1: + subset_name = temp_list[1] + dataset_name = temp_list[0] + if subset_name: + test_dataset = load_dataset(dataset_name, subset_name, split="test") + else: + test_dataset = load_dataset(dataset_name, split="test") + + if dataset_name == "gsm8k": + rnd = random.Random(1234) + train_dataset = load_dataset(dataset_name, "main", split="train") + train_dataset = train_dataset.map(lambda x: {"text": preprocess_gsm8k_val(x)}, batched=True) + train_docs = train_dataset["text"] + fewshot_examples = rnd.sample(train_docs, 5) + fewshot_context = "\n\n".join(fewshot_examples) + "\n\n" + test_dataset = test_dataset.map(lambda x: {"text": fewshot_context + preprocess_gsm8k(x)}) + elif dataset_name == "lighteval/MATH": + fewshot_context = "\n\n".join([math_fewshot_prompt(example) for example in math_fewshot_samples()]) + "\n\n" + test_dataset = test_dataset.map(lambda x: {"text": fewshot_context + preprocess_math(x)}) + elif dataset_name in list(USER_FIN.values()): + test_dataset = test_dataset.map(lambda x: {"text": preprocess_finance_test(x)}) + else: + test_dataset = test_dataset.map(lambda x: {"text": PROCESS_FUNC[dataset_name](x)}) + return test_dataset diff --git a/examples/dataset_llm_workflow/eval_config.py b/examples/dataset_llm_workflow/eval_config.py new file mode 100644 index 00000000..e79de6bf --- /dev/null +++ b/examples/dataset_llm_workflow/eval_config.py @@ -0,0 +1,183 @@ +from typing import List + +from learnware.tests.benchmarks import LLMBenchmarkConfig + + +medical_eval_configs: List[LLMBenchmarkConfig] = [ + LLMBenchmarkConfig( + name="medmcqa", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="medqa_4options", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_anatomy", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_clinical_knowledge", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_college_biology", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_college_medicine", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_medical_genetics", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_professional_medicine", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="pubmedqa", + eval_metric="acc", + ), +] + +math_eval_configs: List[LLMBenchmarkConfig] = [ + LLMBenchmarkConfig( + name="agieval_aqua_rat", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="agieval_gaokao_mathcloze", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="agieval_gaokao_mathqa", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="agieval_math", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="agieval_sat_math", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="cmmlu_college_mathematics", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="cmmlu_elementary_mathematics", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="cmmlu_high_school_mathematics", + eval_metric="acc", + ), + 
LLMBenchmarkConfig( + name="gsm8k", + eval_metric="exact_match,flexible-extract", + ), + LLMBenchmarkConfig( + name="mathqa", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mgsm_native_cot_zh", + eval_metric="exact_match,flexible-extract", + ), + LLMBenchmarkConfig( + name="minerva_math", + eval_metric="exact_match", + ), + LLMBenchmarkConfig( + name="mmlu_abstract_algebra", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_college_mathematics", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_elementary_mathematics", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mmlu_high_school_mathematics", + eval_metric="acc", + ), +] + +finance_eval_configs: List[LLMBenchmarkConfig] = [ + LLMBenchmarkConfig( + name="australian", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="cra_lendingclub", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="fiqasa", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="fpb", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="german", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="headlines", + eval_metric="avg_f1", + ), + LLMBenchmarkConfig( + name="ner", + eval_metric="entity_f1", + ), + LLMBenchmarkConfig( + name="sm_acl", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="sm_bigdata", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="sm_cikm", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="causal20_sc", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="finarg_ecc_arc", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="finarg_ecc_auc", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="fomc", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="ma", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="mlesg", + eval_metric="acc", + ), + LLMBenchmarkConfig( + name="multifin_en", + eval_metric="acc", + ), +] + +CONFIG = {"medical": medical_eval_configs, "math": math_eval_configs, "finance": finance_eval_configs} diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/australian.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/australian.yaml new file mode 100644 index 00000000..2ba0dc42 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/australian.yaml @@ -0,0 +1,2 @@ +task: australian +class: !function flare.Australian diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/causal20_sc.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/causal20_sc.yaml new file mode 100644 index 00000000..03186515 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/causal20_sc.yaml @@ -0,0 +1,2 @@ +task: causal20_sc +class: !function flare.Causal20SC diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cd.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cd.yaml new file mode 100644 index 00000000..594b701d --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cd.yaml @@ -0,0 +1,2 @@ +task: cd +class: !function flare.CD diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/convfinqa.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/convfinqa.yaml new file mode 100644 index 00000000..6806ef5d --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/convfinqa.yaml @@ -0,0 +1,2 @@ +task: convfinqa +class: !function flare.ConvFinQA diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccf.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccf.yaml new file mode 100644 index 00000000..5505edc8 --- /dev/null +++ 
b/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccf.yaml @@ -0,0 +1,2 @@ +task: cra_ccf +class: !function flare.ccf \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccfraud.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccfraud.yaml new file mode 100644 index 00000000..9289db9b --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_ccfraud.yaml @@ -0,0 +1,2 @@ +task: cra_ccfraud +class: !function flare.ccfraud diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_lendingclub.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_lendingclub.yaml new file mode 100644 index 00000000..de7609c2 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_lendingclub.yaml @@ -0,0 +1,2 @@ +task: cra_lendingclub +class: !function flare.lendingclub diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_polish.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_polish.yaml new file mode 100644 index 00000000..3d3d50e1 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_polish.yaml @@ -0,0 +1,2 @@ +task: cra_polish +class: !function flare.polish diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_portoseguro.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_portoseguro.yaml new file mode 100644 index 00000000..6c79245a --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_portoseguro.yaml @@ -0,0 +1,2 @@ +task: cra_portoseguro +class: !function flare.portoseguro diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_taiwan.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_taiwan.yaml new file mode 100644 index 00000000..d0948067 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_taiwan.yaml @@ -0,0 +1,2 @@ +task: cra_taiwan +class: !function flare.taiwan diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/cra_travelinsurace.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/cra_travelinsurace.yaml new file mode 100644 index 00000000..80d70661 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/cra_travelinsurace.yaml @@ -0,0 +1,2 @@ +task: cra_travelinsurace +class: !function flare.travelinsurace \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/ectsum.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/ectsum.yaml new file mode 100644 index 00000000..7bdc06a2 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/ectsum.yaml @@ -0,0 +1,2 @@ +task: ectsum +class: !function flare.ECTSUM \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/edtsum.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/edtsum.yaml new file mode 100644 index 00000000..7dbf158c --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/edtsum.yaml @@ -0,0 +1,2 @@ +task: edtsum +class: !function flare.EDTSUM \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_arc.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_arc.yaml new file mode 100644 index 00000000..bda9917e --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_arc.yaml @@ -0,0 +1,2 @@ +task: finarg_ecc_arc +class: !function flare.FINARGECCARC \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_auc.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_auc.yaml new file mode 
100644 index 00000000..2a04806f --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/finarg_ecc_auc.yaml @@ -0,0 +1,2 @@ +task: finarg_ecc_auc +class: !function flare.FINARGECCAUC \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/finer_ord.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/finer_ord.yaml new file mode 100644 index 00000000..9ed571c6 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/finer_ord.yaml @@ -0,0 +1,2 @@ +task: finer_ord +class: !function flare.FinerOrd diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/finqa.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/finqa.yaml new file mode 100644 index 00000000..e13381d5 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/finqa.yaml @@ -0,0 +1,2 @@ +task: finqa +class: !function flare.FinQA diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/finred.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/finred.yaml new file mode 100644 index 00000000..0c43a65e --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/finred.yaml @@ -0,0 +1,2 @@ +task: finred +class: !function flare.FinRED diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/fiqasa.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/fiqasa.yaml new file mode 100644 index 00000000..e47c3e47 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/fiqasa.yaml @@ -0,0 +1,2 @@ +task: fiqasa +class: !function flare.FIQASA diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/flare.py b/examples/dataset_llm_workflow/extra_tasks/flare/flare.py new file mode 100644 index 00000000..1557bebe --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/flare.py @@ -0,0 +1,1582 @@ +""" +FLARE +""" + +from typing import List +from transformers import BartTokenizer, BartForConditionalGeneration +import traceback +import torch.nn as nn +import torch +from lm_eval.api.instance import Instance +import numpy as np +from seqeval.metrics import f1_score as entity_score +from sklearn.metrics import f1_score, matthews_corrcoef, mean_squared_error +import evaluate +import re +from lm_eval.api.task import ConfigurableTask +import os + + +def mean(arr): + return sum(arr) / len(arr) + + +def process_text(entity_string, text): + # Initialize + entity_list = [(", ".join(val.split(", ")[:-1]), val.split(", ")[-1]) for val in entity_string.split("\n")] + text_words = text.split() + labels = ["O"] * len(text_words) + # text_lower = text.lower() + text_lower = text + + # Create a list to store the start index of each word + word_indices = [0] + for word in text_words[:-1]: + word_indices.append(word_indices[-1] + len(word) + 1) + + # Iterate over the entity list + # print (entity_list) + for entity, entity_type in entity_list: + entity.split() + entity_lower = entity + + # Find start and end index of each occurrence of the entity in the text + start = 0 + while True: + start = text_lower.find(entity_lower, start) + if not entity or start == -1: + break # No more occurrence + end = start + len(entity) - 1 + + # Find the words included in this occurrence + try: + start_word = next(i for i, ind in enumerate(word_indices) if ind >= start) + end_word = next(i for i, ind in enumerate(word_indices) if ind > end) + + # Label the words + labels[start_word] = "B-" + entity_type + for i in range(start_word + 1, end_word): + labels[i] = "I-" + entity_type + + # Move to the next character after the occurrence + except Exception: + pass + 
start = end + 1 + + return labels + + +_CITATION = """ +@misc{xie2023pixiu, + title={PIXIU: A Large Language Model, Instruction Data and Evaluation Benchmark for Finance}, + author={Qianqian Xie and Weiguang Han and Xiao Zhang and Yanzhao Lai and Min Peng and Alejandro Lopez-Lira and Jimin Huang}, + year={2023}, + eprint={2306.05443}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +""" + + +class Classification(ConfigurableTask): + CALCULATE_MCC = True + LOWER_CASE = True + VERSION = 1 + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + # TODO: Format the query prompt portion of the document example. 
+ return doc["answer"] + + def process_results(self, doc, results): + gold: str = doc["choices"][doc["gold"]] + if self.LOWER_CASE: + gold = gold.lower() + ini_result = results[0].strip() + if self.LOWER_CASE: + ini_result = ini_result.lower() + + result = None + for choice in doc["choices"]: + if self.LOWER_CASE: + choice = choice.lower() + if choice in ini_result: + result = choice + break + if result is None: + result = "missing" + + acc = 1.0 if gold == result else 0.0 + + results = { + "acc": acc, + "missing": int(result == "missing"), + "f1": (result, gold), + "macro_f1": (result, gold), + } + + if self.CALCULATE_MCC: + results["mcc"] = (result, gold) + + return results + + def higher_is_better(self): + metrics = { + "acc": True, + "f1": True, + "macro_f1": True, + "missing": False, + } + if self.CALCULATE_MCC: + metrics["mcc"] = True + return metrics + + def weighted_f1(self, items): + preds, golds = zip(*items) + labels = list(set(golds)) + preds = np.array(preds) + golds = np.array(golds) + f1 = f1_score(golds, preds, average="weighted", labels=labels) + return f1 + + def macro_f1(self, items): + preds, golds = zip(*items) + labels = list(set(golds)) + preds = np.array(preds) + golds = np.array(golds) + f1 = f1_score(golds, preds, average="macro", labels=labels) + return f1 + + def matthews_corrcoef(self, items): + preds, golds = zip(*items) + labels = {label: i for i, label in enumerate(list(set(golds)))} + preds = [labels.get(pred, -1) for pred in preds] + golds = [labels.get(gold, -1) for gold in golds] + return matthews_corrcoef(golds, preds) + + def aggregation(self): + metrics = { + "acc": mean, + "missing": mean, + "f1": self.weighted_f1, + "macro_f1": self.macro_f1, + } + if self.CALCULATE_MCC: + metrics["mcc"] = self.matthews_corrcoef + return metrics + + +class SequentialLabeling(ConfigurableTask): + VERSION = 1 + DATASET_NAME = None + LMAP = {"O": 0} + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return "\nAnswer: " + doc["answer"] + + def process_results(self, doc, results): + return { + "entity_f1": (doc["label"], results[0], doc["token"]), + "f1": (doc["label"], results[0], doc["token"]), + } + + def higher_is_better(self): + return { + "f1": True, + "entity_f1": True, + } + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def process_result(self, pred, gold, tokens): + format_pred = ["O"] * len(gold) + for index, pre in enumerate(pred.split("\n")[: len(tokens)]): + try: + word, label = pre.split(":") + except BaseException: + continue + if word == tokens[index] and label in self.LMAP.keys(): + format_pred[index] = label + return format_pred + + def entity_f1(self, items): + golds, preds, tokens = zip(*items) + + list_preds = [self.process_result(pred, gold, token) for pred, gold, token in zip(preds, golds, tokens)] + f1 = entity_score(golds, list_preds) + return f1 + + def process_label_result(self, pred, gold, tokens): + format_pred = [-1] * len(gold) + for index, pre in enumerate(pred.split("\n")[: len(tokens)]): + try: + word, label = pre.split(":") + except BaseException: + continue + if word == tokens[index]: + format_pred[index] = self.LMAP.get(label, -1) + return format_pred + + def label_f1(self, items): + golds, preds, tokens = zip(*items) + + list_preds = [self.process_label_result(pred, gold, token) for pred, gold, token in zip(preds, golds, tokens)] + list_preds = [item for sublist in list_preds for item in sublist] + golds = [self.LMAP[item] for sublist in golds for item in sublist] + f1 = f1_score(golds, list_preds, average="weighted") + return f1 + + def aggregation(self): + return { + "entity_f1": self.entity_f1, + "f1": self.label_f1, + } + + +class AbstractiveSummarization(ConfigurableTask): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["answer"], results[0]), + "rouge2": (doc["answer"], results[0]), + "rougeL": (doc["answer"], results[0]), + "bert_score_f1": (doc["answer"], results[0]), + "bart_score": (doc["answer"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + "bart_score": True, + } + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def rouge_score(self, items): + golds, preds = zip(*items) + rouge = evaluate.load("rouge") + results = rouge.compute(predictions=preds, references=golds) + return results + + def rouge1(self, items): + results = self.rouge_score(items) + return results["rouge1"] + + def rouge2(self, items): + results = self.rouge_score(items) + return results["rouge2"] + + def rougeL(self, items): + results = self.rouge_score(items) + return results["rougeL"] + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, preds = zip(*items) + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-multilingual-cased", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def bart_score(self, items): + golds, preds = zip(*items) + bart_scorer = BARTScorer(device="cuda", checkpoint="facebook/bart-large-cnn") + bart_path = os.path.abspath(os.path.join(__file__, "..", "..", "..")) + bart_path = os.path.join(bart_path, "external_utils", "BARTScore", "bart_score.pth") + bart_scorer.load(path=bart_path) + res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) + return sum(res) / len(res) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + "bart_score": self.bart_score, + } + + +class ExtractiveSummarization(ConfigurableTask): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["label"], doc["text"], results[0]), + "rouge2": (doc["label"], doc["text"], results[0]), + "rougeL": (doc["label"], doc["text"], results[0]), + "bert_score_f1": (doc["label"], doc["text"], results[0]), + "bart_score": (doc["label"], doc["text"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + "bart_score": True, + } + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def get_sum(self, labels, texts): + summ = [] + for label, text in zip(labels, texts): + text = text.split("\n") + new_text = "\n".join( + [text[index] for index in range(len(text)) if index < len(label) and label[index] == 1] + ) + summ.append(new_text) + return summ + + def rouge_score(self, items): + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + rouge = evaluate.load("rouge") + results = rouge.compute(predictions=preds, references=golds) + return results + + def rouge1(self, items): + results = self.rouge_score(items) + return results["rouge1"] + + def rouge2(self, items): + results = self.rouge_score(items) + return results["rouge2"] + + def rougeL(self, items): + results = self.rouge_score(items) + return results["rougeL"] + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-multilingual-cased", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def bart_score(self, items): + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + + bart_scorer = BARTScorer(device="cuda", checkpoint="facebook/bart-large-cnn") + bart_path = os.path.abspath(os.path.join(__file__, "..", "..", "..")) + bart_path = os.path.join(bart_path, "external_utils", "BARTScore", "bart_score.pth") + bart_scorer.load(path=bart_path) + res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) + return sum(res) / len(res) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + "bart_score": self.bart_score, + } + + +class RelationExtraction(ConfigurableTask): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. 
+ return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "precision": (doc["label"], results[0]), + "recall": (doc["label"], results[0]), + "f1": (doc["label"], results[0]), + } + + def higher_is_better(self): + return { + "precision": True, + "recall": True, + "f1": True, + } + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def process(self, items): + golds, preds = zip(*items) + + all_golds = [] + all_preds = [] + + for gold, pred in zip(golds, preds): + all_golds.extend(gold) + pred = pred.split("\n") + all_preds.extend(pred) + + return set(all_golds), set(all_preds) + + def precision(self, items): + golds, preds = self.process(items) + tp = golds & preds + prec = len(tp) / len(preds) + return prec + + def recall(self, items): + golds, preds = self.process(items) + tp = golds & preds + rec = len(tp) / len(golds) + return rec + + def cal_f1(self, items): + prec = self.precision(items) + rec = self.recall(items) + if prec + rec == 0.0: + return 0.0 + return 2 * (prec * rec) / (prec + rec) + + def aggregation(self): + return { + "precision": self.precision, + "recall": self.recall, + "f1": self.cal_f1, + } + + +class QA(ConfigurableTask): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + gold = doc["answer"] + + acc = 1.0 if results[0].strip() == gold else 0.0 + + return { + "acc": acc, + } + + def higher_is_better(self): + return { + "acc": True, + } + + def aggregation(self): + return { + "acc": mean, + } + + +class FPB(Classification): + DATASET_PATH = "chancefocus/flare-fpb" + + +class FIQASA(Classification): + DATASET_PATH = "chancefocus/flare-fiqasa" + + +class NER(ConfigurableTask): + VERSION = 1 + DATASET_PATH = "chancefocus/flare-ner" + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {}), + idx=0, + **kwargs, + ) + ] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + text = doc["text"] + pred = process_text(results[0], text) + + return {"entity_f1": (pred, doc["label"], results[0])} + + def higher_is_better(self): + return { + "entity_f1": True, + } + + @classmethod + def entity_f1(cls, items): + preds, golds, _ = zip(*items) + f1 = entity_score(golds, preds) + return f1 + + def aggregation(self): + return { + "entity_f1": self.entity_f1, + } + + +class FinQA(QA): + DATASET_PATH = "chancefocus/flare-finqa" + + +class StockMovement(Classification): + DATASET_NAME = None + CALCULATE_MCC = True + CHOICE_DICT = { + "rise": ["yes", "positive"], + "fall": ["no", "negative", "neutral"], + } + DEFAULT = "fall" + + def process_results(self, doc, results): + gold: str = doc["choices"][doc["gold"]] + if self.LOWER_CASE: + gold = gold.lower() + ini_result = results[0].strip() + if self.LOWER_CASE: + ini_result = ini_result.lower() + + result = None + for choice in doc["choices"]: + if self.LOWER_CASE: + choice = choice.lower() + if choice in ini_result or any([val in ini_result for val in self.CHOICE_DICT[choice]]): + result = choice + break + if result is None: + result = self.DEFAULT + + acc = 1.0 if gold == result else 0.0 + + results = { + "acc": acc, + "missing": int(result == "missing"), + "f1": (result, gold), + "macro_f1": (result, gold), + } + + if self.CALCULATE_MCC: + results["mcc"] = (result, gold) + + return results + + +class StockMovementBigData(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-bigdata" + + +class StockMovementACL(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-acl" + + +class StockMovementCIKM(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-cikm" + + +SM_TASKS = { + "flare_sm_bigdata": StockMovementBigData, + "flare_sm_acl": StockMovementACL, + "flare_sm_cikm": StockMovementCIKM, +} + + +class Headlines(Classification): + DATASET_PATH = "chancefocus/flare-headlines" + + def process_results(self, doc, results): + gold = doc["gold"] + + return { + "avg_f1": (doc["label_type"], int(results[0].strip() != "Yes"), gold, results), + } + + def higher_is_better(self): + return { + "avg_f1": True, + } + + @classmethod + def label_avg(cls, items): + labels, preds, golds, rels = zip(*items) + label_set = set(labels) + labels = np.array(labels) + preds = np.array(preds) + golds = np.array(golds) + all_f1s = [] + for label_val in label_set: + pds = preds[labels == label_val] + gds = golds[labels == label_val] + f1 = f1_score(gds, pds, average="weighted", labels=[0, 1]) + all_f1s.append(f1) + return np.mean(all_f1s) + + # def construct_requests(self, doc, ctx): + # """Uses RequestFactory to construct Requests and returns an iterable of + # Requests which will be sent to the LM. + + # :param doc: + # The document as returned from training_docs, validation_docs, or test_docs. + # :param ctx: str + # The context string, generated by fewshot_context. This includes the natural + # language description, as well as the few shot examples, and the question + # part of the document for `doc`. 
+ # """ + # cont_request = rf.greedy_until(ctx, {"until": None}) + # return cont_request + + def aggregation(self): + return { + "avg_f1": self.label_avg, + } + + +class FinerOrd(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-finer-ord" + LMAP = { + "O": 0, + "B-PER": 1, + "I-PER": 2, + "B-LOC": 3, + "I-LOC": 4, + "B-ORG": 5, + "I-ORG": 6, + } + + +class FOMC(Classification): + # DATASET_PATH = "chancefocus/flare-fomc" + DATASET_PATH = "TheFinAI/finben-fomc" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class German(StockMovement): + DATASET_PATH = "chancefocus/flare-german" + CHOICE_DICT = { + "good": ["yes", "positive"], + "bad": ["no", "negative", "neutral"], + } + DEFAULT = "good" + + +class Australian(StockMovement): + # DATASET_PATH = "chancefocus/flare-australian" + DATASET_PATH = "TheFinAI/flare-australian" + CHOICE_DICT = { + "good": ["yes", "positive"], + "bad": ["no", "negative", "neutral"], + } + DEFAULT = "good" + + +class ECTSUM(ExtractiveSummarization): + DATASET_PATH = "chancefocus/flare-ectsum" + + +class EDTSUM(AbstractiveSummarization): + DATASET_PATH = "chancefocus/flare-edtsum" + + +class EDTSUM_test(AbstractiveSummarization): + DATASET_PATH = "TheFinAI/flare-edtsum_test" + + +class ConvFinQA(QA): + DATASET_PATH = "chancefocus/flare-convfinqa" + + def reformulate_turn_req(self, req, turn_request, turn): + if turn == 0: + return req + pre_answers = {f"answer{i}": turn_request[i][0] for i in range(turn)} + if pre_answers: + req.args = tuple([req.args[0].format(**pre_answers)] + list(req.args[1:])) + return req + + +class TSA(ConfigurableTask): + VERSION = 1 + DATASET_PATH = "chancefocus/flare-tsa" + DATASET_NAME = None + EVAL_LAST_TURN = True + + def __init__(self, **kwargs): + super().__init__(config={"metadata": {"version": self.VERSION}}) + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return "\nAnswer: " + str(doc["answer"]) + + def process_results(self, doc, results): + pred = results[0].split("\n")[0] + pred = re.findall(r"[0-9]+(?:\.[0-9]+)?", pred) + missing = 0 + if not pred: + pred = -100.0 + missing = 1 + else: + pred = pred[0] + pred = float(pred) + return {"rmse": (doc["answer"], pred), "missing": missing} + + def higher_is_better(self): + return { + "rmse": False, + } + + def construct_requests(self, doc, ctx, **kwargs): + """ + Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + # cont_request = rf.greedy_until(ctx, {"until": "Answer:"}) + # return cont_request + kwargs.pop("apply_chat_template") + return [ + Instance( + request_type="generate_until", + doc=doc, + arguments=(ctx, {"until": "Answer:"}), + idx=0, + **kwargs, + ) + ] + + def rmse(self, items): + golds, preds = zip(*items) + fgolds, fpreds = [], [] + for gold, pred in zip(golds, preds): + if pred == -100.0: + continue + fgolds.append(gold) + fpreds.append(max(min(pred, 1.0), -1.0)) + rmse = mean_squared_error(fgolds, fpreds, squared=True) + + return rmse + + def aggregation(self): + return { + "rmse": self.rmse, + "missing": mean, + } + + +class CFA(Classification): + DATASET_PATH = "chancefocus/flare-cfa" + LOWER_CASE = False + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FINARGECCARC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FINARGECCAUC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FINARGECCAUC_test(Classification): + DATASET_PATH = "TheFinAI/flare-finarg-ecc-auc_test" + + +class MLESG(Classification): + DATASET_PATH = "chancefocus/flare-mlesg" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FSRL(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-fsrl" + LMAP = { + key: index + for index, key in enumerate( + [ + "O", + "I-QUANT", + "B-QUANT", + "I-TIME", + "B-TIME", + "I-MANNER", + "B-MANNER", + "I-THEME", + "B-THEME", + "I-VALUE", + "B-VALUE", + "I-WHOLE", + "B-WHOLE", + "I-LOCATION", + "B-LOCATION", + "I-AGENT", + "B-AGENT", + "I-CAUSE", + "B-CAUSE", + "I-SOURCE", + "B-SOURCE", + "I-REF_TIME", + "B-REF_TIME", + "I-CONDITION", + "B-CONDITION", + ] + ) + } + + +# This class is already defined above at line 1200 +# class CFA(Classification): +# DATASET_PATH = "chancefocus/flare-cfa" +# +# def has_training_docs(self): +# return False +# +# def has_validation_docs(self): +# return False + + +# class FinargECCAUC(Classification): +# DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" + +# class FinargECCARC(Classification): +# DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" + + +class CD(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-cd" + LMAP = {key: index for index, key in enumerate(["O", "I-CAUSE", "B-CAUSE", "I-EFFECT", "B-EFFECT"])} + + +class MultiFinEN(Classification): + DATASET_PATH = "chancefocus/flare-multifin-en" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class MA(Classification): + DATASET_PATH = "chancefocus/flare-ma" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class Causal20SC(Classification): + DATASET_PATH = "chancefocus/flare-causal20-sc" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FNXL(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-fnxl" + LMAP = { + "B-BusinessCombinationContingentConsiderationArrangementsRangeOfOutcomesValueHigh": 140, + "B-VariableInterestEntityOwnershipPercentage": 646, + "B-GainLossOnDispositionOfAssets1": 119, + "B-IndefiniteLivedIntangibleAssetsExcludingGoodwill": 46, + "B-MarketingAndAdvertisingExpense": 269, + 
"B-ReportingUnitPercentageOfFairValueInExcessOfCarryingAmount": 142, + "B-CapitalizedComputerSoftwareNet": 91, + "B-BusinessCombinationConsiderationTransferredEquityInterestsIssuedAndIssuable": 183, + "B-LitigationSettlementExpense": 115, + "B-DefinedBenefitPlanExpectedAmortizationOfGainLossNextFiscalYear": 639, + "B-DeferredCompensationArrangementWithIndividualCompensationExpense": 15, + "B-ReclassificationFromAociCurrentPeriodTax": 152, + "B-OtherComprehensiveIncomeLossBeforeReclassificationsTax": 694, + "B-PreferredStockDividendsPerShareDeclared": 236, + "B-CapitalExpendituresIncurredButNotYetPaid": 344, + "B-DeferredCompensationArrangementWithIndividualContributionsByEmployer": 560, + "B-SeveranceCosts1": 311, + "B-InterestExpense": 784, + "B-SaleOfStockConsiderationReceivedOnTransaction": 76, + "B-LineOfCreditFacilityInterestRateAtPeriodEnd": 822, + "B-SharesIssuedPricePerShare": 137, + "B-EquityMethodInvestmentDifferenceBetweenCarryingAmountAndUnderlyingEquity": 63, + "B-EquitySecuritiesFvNi": 30, + "B-RightOfUseAssetObtainedInExchangeForOperatingLeaseLiability": 118, + "B-DefinedBenefitPlanFundedStatusOfPlan": 547, + "B-SharebasedCompensationArrangementBySharebasedPaymentAwardPurchasePriceOfCommonStockPercent": 323, + "B-TaxCutsAndJobsActOf2017IncomeTaxExpenseBenefit": 256, + "B-LongtermDebtWeightedAverageInterestRate": 364, + "B-ImpairmentOfIntangibleAssetsFinitelived": 71, + "B-ProceedsFromLinesOfCredit": 496, + "B-LongTermPurchaseCommitmentAmount": 701, + "B-DebtInstrumentFairValue": 335, + "B-RestructuringAndRelatedCostCostIncurredToDate1": 52, + "B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsVestedInPeriod": 581, + "B-FiniteLivedIntangibleAssetsAccumulatedAmortization": 143, + "B-StockRepurchasedAndRetiredDuringPeriodValue": 330, + "B-BusinessCombinationProFormaInformationRevenueOfAcquireeSinceAcquisitionDateActual": 77, + "B-ClassOfWarrantOrRightExercisePriceOfWarrantsOrRights1": 361, + "B-BusinessAcquisitionPurchasePriceAllocationGoodwillExpectedTaxDeductibleAmount": 550, + "B-OperatingLossCarryforwardsValuationAllowance": 173, + "B-BusinessAcquisitionEquityInterestsIssuedOrIssuableNumberOfSharesIssued": 32, + "B-DefinedContributionPlanMaximumAnnualContributionsPerEmployeePercent": 45, + "B-ContractWithCustomerLiabilityCurrent": 2, + "B-IncomeLossFromContinuingOperationsBeforeIncomeTaxesForeign": 474, + "B-FiniteLivedIntangibleAssetsAmortizationExpenseYearThree": 1306, + "B-DefinedBenefitPlanUltimateHealthCareCostTrendRate1": 62, + "B-DefinedBenefitPlanRecognizedNetGainLossDueToSettlements1": 317, + "B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesExpense": 448, + "B-ForeignCurrencyTransactionGainLossRealized": 132, + "B-DeferredTaxAssetsOperatingLossCarryforwardsSubjectToExpiration": 262, + "B-RetainedEarningsAccumulatedDeficit": 174, + "B-ProceedsFromIssuanceOfCommonStock": 209, + "B-EmployeeServiceShareBasedCompensationAllocationOfRecognizedPeriodCostsCapitalizedAmount": 29, + "B-OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansTax": 284, + "B-InventoryWriteDown": 465, + "B-RestructuringReserve": 234, + "B-LitigationSettlementAmountAwardedToOtherParty": 42, + "B-DerivativeGainLossOnDerivativeNet": 87, + "B-SharebasedCompensationArrangementBySharebasedPaymentAwardEquityInstrumentsOtherThanOptionsAggregateIntrinsicValueVested": 241, + "B-DerivativeFixedInterestRate": 589, + "B-CashAndCashEquivalentsAtCarryingValue": 257, + "B-ContractWithCustomerAssetNet": 245, + "B-RestructuringAndRelatedCostExpectedCost1": 
107, + "B-IncomeTaxHolidayAggregateDollarAmount": 347, + "B-OperatingLeaseCost": 248, + "B-AllowanceForDoubtfulAccountsReceivable": 146, + "B-RepaymentsOfDebt": 416, + "B-InterestPaid": 110, + "B-DeferredFinanceCostsNet": 28, + "B-IncomeTaxExaminationPenaltiesAndInterestAccrued": 271, + "B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsNonvestedNumber": 92, + "B-CapitalizedContractCostNet": 155, + "B-CumulativeEffectOfNewAccountingPrincipleInPeriodOfAdoption": 17, + "B-IncomeTaxesPaid": 495, + "B-EquityMethodInvestmentOtherThanTemporaryImpairment": 22, + "B-InterestPaidNet": 225, + "B-EquitySecuritiesWithoutReadilyDeterminableFairValueAmount": 175, + "B-ImpairmentOfLongLivedAssetsHeldForUse": 313, + "B-GoodwillAcquiredDuringPeriod": 156, + "B-DecreaseInUnrecognizedTaxBenefitsIsReasonablyPossible": 363, + "B-RestructuringAndRelatedCostIncurredCost": 75, + "B-StockRepurchasedDuringPeriodValue": 254, + "B-IncomeTaxExaminationPenaltiesAndInterestExpense": 525, + "B-ImpairmentOfIntangibleAssetsIndefinitelivedExcludingGoodwill": 55, + "B-PreferredStockLiquidationPreference": 157, + "B-ImpairmentOfIntangibleAssetsExcludingGoodwill": 158, + "B-IncomeTaxesPaidNet": 456, + "B-DefinedContributionPlanEmployerMatchingContributionPercent": 332, + "B-CostOfGoodsAndServicesSold": 274, + "B-DepreciationDepletionAndAmortization": 338, + "B-InterestExpenseDebt": 191, + "B-LineOfCreditFacilityUnusedCapacityCommitmentFeePercentage": 442, + "B-DisposalGroupIncludingDiscontinuedOperationConsideration": 6, + "B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesAccrued": 14, + "B-SaleOfStockPricePerShare": 278, + "B-DefinedContributionPlanEmployerMatchingContributionPercentOfMatch": 267, + "B-FinitelivedIntangibleAssetsAcquired1": 202, + "B-PaymentsForRepurchaseOfCommonStock": 486, + "B-BusinessCombinationContingentConsiderationLiability": 103, + "B-RelatedPartyTransactionAmountsOfTransaction": 180, + "O": 0, + } + + +class TATQA(QA): + DATASET_PATH = "chancefocus/flare-tatqa" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + +class FinRED(RelationExtraction): + DATASET_PATH = "chancefocus/flare-finred" + + +class lendingclub(Classification): + # DATASET_PATH = "chancefocus/cra-lendingclub" + DATASET_PATH = "TheFinAI/cra-lendingclub" + CALCULATE_MCC = True + + +class ccf(Classification): + DATASET_PATH = "chancefocus/cra-ccf" + CALCULATE_MCC = True + + +class ccfraud(Classification): + DATASET_PATH = "chancefocus/cra-ccfraud" + CALCULATE_MCC = True + + +class polish(Classification): + DATASET_PATH = "chancefocus/cra-polish" + CALCULATE_MCC = True + + +class taiwan(Classification): + DATASET_PATH = "chancefocus/cra-taiwan" + CALCULATE_MCC = True + + +class portoseguro(Classification): + DATASET_PATH = "chancefocus/cra-portoseguro" + CALCULATE_MCC = True + + +class travelinsurace(Classification): + DATASET_PATH = "chancefocus/cra-travelinsurace" + CALCULATE_MCC = True + + +############### + +# %% + + +class BARTScorer: + def __init__(self, device="cuda:0", max_length=1024, checkpoint="facebook/bart-large-cnn"): + # Set up model + self.device = device + self.max_length = max_length + self.tokenizer = BartTokenizer.from_pretrained(checkpoint) + self.model = BartForConditionalGeneration.from_pretrained(checkpoint) + self.model.eval() + self.model.to(device) + + # Set up loss + self.loss_fct = nn.NLLLoss(reduction="none", ignore_index=self.model.config.pad_token_id) + self.lsm = nn.LogSoftmax(dim=1) + + def load(self, path=None): + 
"""Load model from paraphrase finetuning""" + if path is None: + path = "models/bart.pth" + self.model.load_state_dict(torch.load(path, map_location=self.device)) + + def score(self, srcs, tgts, batch_size=4): + """Score a batch of examples""" + score_list = [] + for i in range(0, len(srcs), batch_size): + src_list = srcs[i : i + batch_size] + tgt_list = tgts[i : i + batch_size] + try: + with torch.no_grad(): + encoded_src = self.tokenizer( + src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + encoded_tgt = self.tokenizer( + tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + src_tokens = encoded_src["input_ids"].to(self.device) + src_mask = encoded_src["attention_mask"].to(self.device) + + tgt_tokens = encoded_tgt["input_ids"].to(self.device) + tgt_mask = encoded_tgt["attention_mask"] + tgt_len = tgt_mask.sum(dim=1).to(self.device) + + output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens) + logits = output.logits.view(-1, self.model.config.vocab_size) + loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1)) + loss = loss.view(tgt_tokens.shape[0], -1) + loss = loss.sum(dim=1) / tgt_len + curr_score_list = [-x.item() for x in loss] + score_list += curr_score_list + + except RuntimeError: + traceback.print_exc() + print(f"source: {src_list}") + print(f"target: {tgt_list}") + exit(0) + return score_list + + def multi_ref_score(self, srcs, tgts: List[List[str]], agg="mean", batch_size=4): + # Assert we have the same number of references + ref_nums = [len(x) for x in tgts] + if len(set(ref_nums)) > 1: + raise Exception("You have different number of references per test sample.") + + ref_num = len(tgts[0]) + score_matrix = [] + for i in range(ref_num): + curr_tgts = [x[i] for x in tgts] + scores = self.score(srcs, curr_tgts, batch_size) + score_matrix.append(scores) + if agg == "mean": + score_list = np.mean(score_matrix, axis=0) + elif agg == "max": + score_list = np.max(score_matrix, axis=0) + else: + raise NotImplementedError + return list(score_list) + + def test(self, batch_size=3): + """Test""" + src_list = [ + "This is a very good idea. 
Although simple, but very insightful.", + "Can I take a look?", + "Do not trust him, he is a liar.", + ] + + tgt_list = ["That's stupid.", "What's the problem?", "He is trustworthy."] + + print(self.score(src_list, tgt_list, batch_size)) diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/fnxl.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/fnxl.yaml new file mode 100644 index 00000000..481dedba --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/fnxl.yaml @@ -0,0 +1,2 @@ +task: fnxl +class: !function flare.FNXL diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/fomc.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/fomc.yaml new file mode 100644 index 00000000..49283578 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/fomc.yaml @@ -0,0 +1,2 @@ +task: fomc +class: !function flare.FOMC \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/fpb.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/fpb.yaml new file mode 100644 index 00000000..d0f31f7d --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/fpb.yaml @@ -0,0 +1,2 @@ +task: fpb +class: !function flare.FPB diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/fsrl.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/fsrl.yaml new file mode 100644 index 00000000..b63cc9a4 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/fsrl.yaml @@ -0,0 +1,2 @@ +task: fsrl +class: !function flare.FSRL diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/german.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/german.yaml new file mode 100644 index 00000000..ac900011 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/german.yaml @@ -0,0 +1,2 @@ +task: german +class: !function flare.German diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/headlines.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/headlines.yaml new file mode 100644 index 00000000..6178d264 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/headlines.yaml @@ -0,0 +1,2 @@ +task: headlines +class: !function flare.Headlines \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/ma.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/ma.yaml new file mode 100644 index 00000000..481b3efe --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/ma.yaml @@ -0,0 +1,2 @@ +task: ma +class: !function flare.MA \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/mlesg.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/mlesg.yaml new file mode 100644 index 00000000..f32b23ec --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/mlesg.yaml @@ -0,0 +1,2 @@ +task: mlesg +class: !function flare.MLESG \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/multifin_en.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/multifin_en.yaml new file mode 100644 index 00000000..091e8e4b --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/multifin_en.yaml @@ -0,0 +1,2 @@ +task: multifin_en +class: !function flare.MultiFinEN \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/ner.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/ner.yaml new file mode 100644 index 00000000..88074566 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/ner.yaml @@ -0,0 +1,2 @@ +task: ner +class: !function flare.NER 
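Note: each flare task YAML above binds a task name to its implementation in flare.py through the `!function` tag, following the external-task convention of lm-evaluation-harness. A minimal sketch of how such an extra-task directory is typically registered and evaluated is given below; the include path, model id, and chosen task names are illustrative assumptions rather than part of this patch.

    import lm_eval
    from lm_eval.models.huggingface import HFLM

    # Make lm-eval aware of the extra YAML task configs (directory path is an assumption).
    task_manager = lm_eval.tasks.TaskManager(include_path="examples/dataset_llm_workflow/extra_tasks")

    # Evaluate a Hugging Face model on a couple of the registered flare tasks.
    lm = HFLM(pretrained="Qwen/Qwen2.5-7B", batch_size="auto")
    results = lm_eval.simple_evaluate(model=lm, tasks=["ner", "fpb"], task_manager=task_manager)
    print(results["results"])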
diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/sm_acl.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/sm_acl.yaml new file mode 100644 index 00000000..5049f2a5 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/sm_acl.yaml @@ -0,0 +1,2 @@ +task: sm_acl +class: !function flare.StockMovementACL \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/sm_bigdata.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/sm_bigdata.yaml new file mode 100644 index 00000000..abdaa2da --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/sm_bigdata.yaml @@ -0,0 +1,2 @@ +task: sm_bigdata +class: !function flare.StockMovementBigData \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/sm_cikm.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/sm_cikm.yaml new file mode 100644 index 00000000..6d52f730 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/sm_cikm.yaml @@ -0,0 +1,2 @@ +task: sm_cikm +class: !function flare.StockMovementCIKM \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/tatqa.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/tatqa.yaml new file mode 100644 index 00000000..8cf461e3 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/tatqa.yaml @@ -0,0 +1,2 @@ +task: tatqa +class: !function flare.TATQA \ No newline at end of file diff --git a/examples/dataset_llm_workflow/extra_tasks/flare/tsa.yaml b/examples/dataset_llm_workflow/extra_tasks/flare/tsa.yaml new file mode 100644 index 00000000..efd902f0 --- /dev/null +++ b/examples/dataset_llm_workflow/extra_tasks/flare/tsa.yaml @@ -0,0 +1,2 @@ +task: tsa +class: !function flare.TSA diff --git a/examples/dataset_llm_workflow/model_performance/finance.csv b/examples/dataset_llm_workflow/model_performance/finance.csv new file mode 100644 index 00000000..a766dbb7 --- /dev/null +++ b/examples/dataset_llm_workflow/model_performance/finance.csv @@ -0,0 +1,19 @@ +Dataset,Qwen2.5-7B,Llama3.1-8B-Instruct,Llama3.1-8B,Qwen1.5-110B,Qwen2.5-72B,Llama3.1-70B-Instruct,australian-1,australian-2,australian-3,australian-4,australian-5,australian-6,australian-7,australian-8,cra_lendingclub-1,cra_lendingclub-2,cra_lendingclub-3,cra_lendingclub-4,cra_lendingclub-5,cra_lendingclub-6,fpb-1,fpb-2,fpb-3,fpb-4,german-1,german-2,german-3,german-4,german-5,german-6,german-7,german-8,headlines-1,headlines-2,headlines-3,ner-1,ner-2,ner-3,ner-4,sm_cikm-1,sm_cikm-2,sm_cikm-3,sm_cikm-4,sm_cikm-5,sm_cikm-6,sm_cikm-7,sm_cikm-8,sm_cikm-9,sm_cikm-10,sm_cikm-11,sm_cikm-12,sm_cikm-13,sm_cikm-14,sm_cikm-15,sm_cikm-16,sm_cikm-17,sm_acl-1,sm_acl-2,sm_acl-3,sm_acl-4,sm_acl-5,sm_acl-6,sm_acl-7,sm_acl-8,sm_bigdata-1,sm_bigdata-2,sm_bigdata-3,sm_bigdata-4,sm_bigdata-5,sm_bigdata-6,sm_bigdata-7,sm_bigdata-8,sm_bigdata-9,fiqasa-1,fiqasa-2,fiqasa-3,fiqasa-4,fiqasa-5,fiqasa-6,fiqasa-7,fiqasa-8 +australian,43.17,44.6,43.17,43.17,43.17,47.48,66.91,46.76,47.48,63.31,47.48,53.96,69.06,59.71,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,42.45,41.73,42.45,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.88,43.17,43.17,43.17,43.17,43.17,43.17,43.17,43.17 
+cra_lendingclub,80.82,76.33,57.34,80.82,47.01,53.07,74.02,76.44,79.64,75.88,79.3,80.12,80.42,80.34,80.82,97.51,93.98,97.4,89.15,93.57,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.38,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,80.82,77.11,80.82,80.82,76.25,80.82,80.79,80.82,80.71,79.67,80.82,79.67,79.26,79.38,80.82,80.82,80.82,80.82,78.86,80.82,80.82,80.01,80.82,80.82,79.38,80.82,80.79,80.82,80.82,80.82,80.82,80.82,80.82 +fiqasa,38.3,40.43,56.17,63.4,64.26,68.51,40.43,41.7,39.15,51.91,39.57,37.45,39.15,46.81,45.11,48.51,45.53,49.36,45.11,51.91,50.64,52.34,50.21,48.94,37.45,38.72,34.04,34.47,33.19,35.74,35.74,35.32,33.62,32.77,29.79,42.98,41.28,52.34,39.15,38.3,40.85,43.83,37.87,42.55,47.23,40.85,43.83,48.94,49.36,55.74,34.47,33.62,54.89,34.04,44.68,33.62,37.45,44.68,40.0,47.23,43.83,41.7,37.45,31.91,44.26,58.3,66.38,42.13,56.6,63.83,41.7,46.81,66.38,78.3,71.49,78.72,85.11,71.91,69.36,78.72,77.45 +fpb,76.08,32.78,30.72,70.72,78.35,78.04,32.37,33.09,34.64,54.85,32.27,33.09,35.05,42.78,74.85,72.16,74.54,73.51,75.05,73.4,82.58,84.43,85.57,84.43,74.02,73.81,75.46,76.19,75.15,74.12,76.19,75.77,78.56,78.04,76.6,75.57,75.36,71.96,75.57,75.67,74.74,74.74,75.57,74.54,73.92,75.46,74.33,74.33,71.55,71.86,75.77,75.77,70.52,75.98,72.78,75.88,75.98,76.8,75.88,72.16,76.19,73.3,76.39,75.67,74.54,72.37,70.52,74.95,73.92,72.06,75.36,74.85,67.94,36.91,40.52,29.38,28.66,40.1,36.49,29.38,30.72 +german,65.0,49.5,66.0,66.0,66.5,43.5,46.5,49.5,44.0,54.0,46.5,41.5,63.5,66.0,59.0,66.0,61.5,35.5,56.0,40.0,36.5,40.0,35.5,35.5,66.0,69.5,67.5,67.0,66.0,66.0,68.0,66.5,65.5,65.0,65.5,36.5,46.0,33.5,56.5,41.5,39.5,37.0,38.5,44.0,35.0,38.0,39.5,39.0,41.0,41.0,35.5,40.5,40.0,38.0,41.5,40.5,64.0,64.5,48.0,66.0,40.0,53.5,35.5,48.5,36.5,35.5,39.0,37.5,37.0,64.0,35.0,37.5,54.0,63.5,66.5,66.0,64.5,63.0,65.5,62.0,64.5 +headlines,74.81,59.95,59.95,62.96,77.84,77.53,59.95,59.95,59.95,59.95,59.95,59.95,59.95,59.95,72.19,73.97,74.39,71.99,73.01,73.84,77.51,78.25,78.1,77.94,72.25,71.8,74.55,74.15,72.1,69.66,74.32,74.08,93.21,95.87,97.76,68.58,72.91,64.07,73.58,73.95,73.26,73.16,73.85,72.9,72.31,73.25,73.66,72.72,72.09,72.02,74.66,74.54,68.28,74.33,71.64,73.96,74.08,71.52,72.68,71.25,73.78,70.8,74.78,75.08,71.57,68.77,65.3,71.07,69.57,65.69,70.54,71.41,67.05,59.95,59.95,59.95,59.95,59.95,59.95,59.95,59.95 +ner,21.75,0.62,9.01,17.89,9.36,9.52,8.13,3.45,7.53,23.06,7.86,10.15,14.12,17.88,29.23,25.64,21.96,28.88,24.59,25.09,24.66,23.5,21.93,22.62,27.86,24.71,23.79,26.38,28.97,26.64,21.69,25.28,23.77,24.8,23.36,57.82,47.07,59.42,46.84,21.87,24.57,24.15,22.13,23.28,25.37,24.52,21.59,25.82,23.48,26.59,19.54,22.61,30.18,21.76,25.79,21.15,26.27,20.12,27.12,26.94,26.15,26.99,28.8,25.14,23.29,27.98,26.94,21.0,26.68,25.79,24.25,26.87,24.39,26.46,19.39,8.26,3.0,14.61,5.68,9.14,8.55 +sm_acl,51.1,51.4,51.34,49.3,51.56,49.38,51.18,51.77,51.75,51.91,51.32,52.04,50.78,51.4,51.02,50.94,50.65,51.08,50.86,50.48,50.56,50.43,50.38,50.51,50.81,51.13,50.94,51.16,51.02,50.94,51.34,51.1,50.67,50.7,50.75,50.65,51.21,50.75,51.29,50.54,50.89,51.8,50.94,50.27,51.16,50.43,50.4,51.08,50.81,50.56,50.08,50.73,51.53,50.7,52.34,50.32,52.69,52.77,53.82,50.65,55.99,52.1,52.39,52.12,52.07,52.15,57.42,52.5,52.23,58.06,52.42,51.96,53.84,50.83,51.13,51.1,51.4,51.67,51.24,51.32,52.02 
+sm_bigdata,55.3,55.57,52.79,51.02,50.27,47.76,55.23,53.6,54.35,55.3,55.91,55.64,55.57,54.62,55.16,54.96,55.16,55.23,55.1,55.23,55.84,56.05,55.71,55.91,55.43,55.37,55.37,55.37,55.03,55.1,55.37,55.64,54.96,55.84,55.77,55.5,55.37,55.3,55.57,53.94,52.11,51.29,54.01,52.51,52.45,54.08,51.56,51.43,51.97,50.07,53.8,53.6,51.97,54.08,52.99,54.48,44.63,45.65,47.08,44.16,47.15,46.94,50.14,50.75,49.18,57.54,51.09,49.59,50.75,54.28,51.15,53.74,54.28,55.03,55.3,53.53,53.74,54.14,53.67,52.65,54.14 +sm_cikm,58.44,54.24,54.07,44.01,58.27,47.86,55.03,55.64,54.77,54.94,55.82,55.64,54.07,53.98,58.53,58.71,58.36,58.27,58.71,58.53,58.27,57.74,58.18,57.48,58.36,58.18,58.62,58.36,58.62,57.92,58.79,58.79,58.01,58.01,57.92,58.09,58.79,57.39,58.62,57.13,56.87,56.43,58.09,56.87,56.43,56.34,56.43,54.16,57.22,56.52,56.26,56.52,53.37,56.96,49.96,56.34,43.57,46.81,50.22,41.82,53.81,53.98,51.62,53.54,56.17,52.93,51.88,54.16,52.49,52.58,53.72,56.26,51.97,52.41,54.86,53.89,55.64,50.31,52.32,54.68,56.87 +causal20_sc,65.14,88.48,79.45,83.75,76.17,87.16,86.9,87.36,86.79,84.26,86.62,86.91,86.4,86.93,71.27,70.2,76.73,69.17,70.9,69.51,83.63,83.98,84.21,84.86,71.12,71.02,69.92,72.47,76.96,78.55,75.57,75.09,85.16,89.74,90.94,67.66,68.88,56.43,69.82,73.64,72.3,69.89,70.19,70.86,69.15,71.99,68.15,66.47,76.32,78.01,74.01,71.36,70.77,67.9,68.0,73.22,70.39,67.14,68.18,66.91,70.18,70.41,73.41,77.2,67.37,59.23,60.3,66.49,61.07,53.43,63.77,68.41,59.03,74.95,70.83,67.57,71.15,77.13,76.18,71.7,74.72 +finarg_ecc_arc,64.78,46.67,60.0,62.32,63.04,44.64,49.28,47.39,43.91,47.83,47.39,44.35,47.54,44.78,65.36,65.65,65.07,65.8,64.93,64.64,65.07,63.91,63.62,64.64,64.64,65.07,65.51,64.78,65.07,65.51,66.09,65.65,62.03,56.81,54.78,64.78,64.06,63.77,64.93,65.07,66.96,65.22,64.93,66.96,65.65,65.22,65.65,65.22,66.96,68.26,65.36,65.65,66.09,65.22,66.67,65.8,67.39,67.25,68.12,67.54,65.8,67.54,67.1,67.1,68.26,68.84,68.55,68.12,67.39,69.57,68.41,66.96,69.13,56.81,53.04,56.38,66.67,56.23,56.38,55.65,59.28 +finarg_ecc_auc,48.3,51.81,49.85,55.01,61.71,65.02,52.53,52.01,52.63,52.73,53.15,52.43,52.53,52.22,53.87,56.04,53.97,55.52,55.42,55.11,58.1,56.76,58.62,58.82,51.5,51.91,51.19,54.08,54.39,54.18,54.7,54.28,49.23,48.4,48.4,55.11,50.46,52.43,49.23,53.25,53.04,52.84,52.43,52.12,52.32,54.28,52.43,50.88,52.94,51.81,48.71,49.23,50.57,48.3,50.15,49.85,47.88,49.43,48.5,47.68,48.81,47.57,48.81,49.23,52.22,51.91,50.05,54.7,52.43,50.98,51.7,49.95,47.99,51.6,50.77,49.23,51.91,48.92,50.36,46.34,52.43 +fomc,60.48,29.44,34.68,58.47,57.66,66.13,30.44,33.27,35.69,35.08,31.85,35.48,34.48,34.68,61.09,61.69,61.29,61.9,62.1,62.5,62.9,63.31,62.3,62.3,60.48,59.48,60.28,60.08,60.48,61.09,60.48,60.69,61.9,62.1,60.08,60.28,61.09,61.49,60.89,60.48,59.68,60.69,61.29,59.27,59.88,60.48,60.89,59.88,58.87,58.06,60.69,60.48,60.89,59.88,61.49,60.48,60.08,60.69,61.29,60.08,61.09,60.89,60.28,61.49,60.69,58.47,56.45,61.29,60.48,60.69,62.1,60.89,60.48,38.71,41.33,40.93,41.73,35.89,35.48,37.5,35.28 +ma,79.2,56.4,51.0,81.4,84.6,83.2,67.2,70.0,72.2,67.8,68.2,69.6,70.8,83.8,69.6,69.8,73.8,69.8,70.4,72.2,77.6,78.0,77.2,77.6,73.2,74.8,76.2,74.0,67.4,66.0,73.4,72.0,76.8,80.6,80.4,72.2,75.6,70.4,75.4,79.4,80.6,79.4,79.2,78.6,79.2,78.0,78.4,78.6,81.6,84.0,80.6,79.8,80.2,79.2,80.2,79.8,80.0,77.8,77.0,78.4,79.4,80.4,78.8,81.8,77.6,71.0,78.6,75.6,76.0,74.0,72.4,75.6,71.4,57.0,58.6,59.6,60.4,56.8,57.0,57.0,59.0 
+mlesg,35.67,32.67,20.0,34.67,38.67,42.33,28.67,31.0,32.33,29.67,32.0,30.33,29.0,29.67,30.0,30.33,30.0,30.0,29.67,30.0,35.33,35.33,35.67,36.0,34.67,35.0,35.0,35.0,33.33,32.0,34.33,34.33,38.67,36.0,40.33,29.67,37.0,30.0,37.0,33.33,33.67,35.0,35.0,33.33,34.67,34.33,35.0,35.67,32.33,33.33,33.33,34.67,33.33,35.33,33.33,34.33,35.67,32.67,33.33,35.0,34.33,31.67,34.33,32.33,33.0,33.33,30.67,33.67,32.33,34.33,34.33,34.33,32.0,15.33,15.67,16.33,18.67,17.33,15.67,20.33,17.67 +multifin_en,60.99,31.32,28.39,65.38,63.55,68.5,29.85,30.77,27.66,32.6,29.3,32.42,28.94,31.32,61.17,61.54,61.36,60.26,60.07,60.26,61.9,64.29,64.1,63.55,60.62,61.72,61.72,60.62,60.99,61.17,60.81,60.26,60.26,57.88,57.69,61.72,60.81,59.71,60.81,60.26,60.62,60.44,60.44,60.44,60.44,59.71,60.62,61.54,59.71,58.24,60.81,61.54,58.97,61.9,58.79,60.44,60.81,62.09,61.17,59.34,61.17,59.89,61.54,62.45,61.54,60.26,58.06,61.36,61.17,61.72,61.17,60.62,60.26,32.23,35.53,34.98,33.52,30.77,32.23,32.23,29.3 +Avg,57.61,47.19,47.29,58.25,58.35,57.63,49.1,48.45,48.5,52.65,48.5,48.89,51.26,52.76,57.73,59.22,58.91,57.46,57.9,57.61,59.12,59.55,59.13,59.12,57.79,58.01,57.89,58.12,57.81,57.57,58.28,58.16,59.74,59.71,59.61,57.71,58.23,56.62,58.78,56.61,56.69,56.46,56.38,56.62,56.42,56.53,56.26,56.45,57.07,57.43,55.74,56.15,56.53,55.74,56.13,56.13,56.75,56.63,56.26,56.35,56.48,56.48,56.2,56.96,56.01,56.03,56.07,55.77,56.12,57.89,55.41,56.48,56.67,51.41,51.11,49.99,51.18,50.16,49.5,49.56,50.35 diff --git a/examples/dataset_llm_workflow/model_performance/math.csv b/examples/dataset_llm_workflow/model_performance/math.csv new file mode 100644 index 00000000..47781a63 --- /dev/null +++ b/examples/dataset_llm_workflow/model_performance/math.csv @@ -0,0 +1,18 @@ +Dataset,Qwen2.5-7B,Qwen1.5-110B,orca-math-word-problems-200k-1,orca-math-word-problems-200k-2,MWP-Instruct-1,GSM8K_zh-1,MATH_train-1,Arithmo-Data-1,MetaMathQA-1,MetaMathQA-2,MetaMath-GSM240K-1,school_math_0.25M-1,school_math_0.25M-2,MathInstruct-1,MathInstruct-2 +agieval_aqua_rat,41.73,38.98,40.94,41.73,41.34,38.98,39.37,38.98,40.55,40.55,41.73,38.19,38.98,39.37,40.55 +agieval_gaokao_mathcloze,16.95,38.14,12.71,13.56,11.86,7.63,17.8,13.56,5.93,5.08,14.41,9.32,6.78,9.32,17.8 +agieval_gaokao_mathqa,49.86,77.78,50.71,51.28,51.57,51.57,50.71,49.86,52.99,53.85,50.14,45.3,45.58,50.43,48.43 +agieval_math,19.8,19.3,19.8,17.2,20.2,20.6,19.6,17.0,28.7,28.1,21.1,16.7,16.9,18.2,20.1 +agieval_sat_math,55.91,57.27,57.73,57.27,57.27,55.45,50.0,55.45,55.45,57.27,55.0,54.55,56.82,55.0,55.0 +cmmlu_college_mathematics,45.71,47.62,50.48,46.67,51.43,52.38,45.71,49.52,48.57,47.62,49.52,50.48,52.38,47.62,47.62 +cmmlu_elementary_mathematics,65.65,77.83,65.22,64.78,64.35,66.96,65.65,65.65,67.39,66.96,64.78,54.35,56.09,66.09,65.22 +cmmlu_high_school_mathematics,61.59,77.44,64.63,64.02,60.98,61.59,62.8,63.41,64.02,62.8,64.63,56.71,57.32,64.02,64.63 +gsm8k,84.08,84.91,84.0,83.85,82.79,74.37,76.19,80.14,83.85,83.47,84.15,81.5,82.18,79.15,80.97 +mathqa,43.32,48.07,45.63,46.93,42.61,36.65,41.64,41.41,42.38,42.38,41.17,40.13,40.57,40.94,41.27 +mgsm_native_cot_zh,66.4,68.8,67.6,70.0,66.0,73.6,66.4,68.0,68.0,68.8,71.6,61.2,57.2,65.6,68.0 +minerva_math,40.16,47.9,40.9,41.56,39.42,29.64,41.96,45.12,32.2,30.32,36.48,43.14,42.4,27.58,29.94 +abstract_algebra,54.0,53.0,52.0,52.0,55.0,56.0,55.0,53.0,56.0,54.0,54.0,51.0,52.0,52.0,54.0 +college_mathematics,53.0,52.0,52.0,55.0,53.0,51.0,58.0,53.0,53.0,55.0,56.0,48.0,51.0,53.0,56.0 +elementary_mathematics,72.75,78.84,72.75,73.28,75.13,74.07,73.54,75.13,73.81,73.02,73.02,70.9,71.96,73.54,74.34 
+high_school_mathematics,55.93,60.0,55.19,55.19,55.56,55.93,55.93,55.93,57.04,56.67,55.56,50.74,51.11,54.81,55.19 +Avg,51.68,57.99,52.02,52.14,51.78,50.4,51.27,51.57,51.87,51.62,52.08,48.26,48.7,49.79,51.19 diff --git a/examples/dataset_llm_workflow/model_performance/medical.csv b/examples/dataset_llm_workflow/model_performance/medical.csv new file mode 100644 index 00000000..7f557ef7 --- /dev/null +++ b/examples/dataset_llm_workflow/model_performance/medical.csv @@ -0,0 +1,11 @@ +Dataset,Qwen2.5-7B,Flan-PaLM-540B,medqa_train&pubmed_causal-1,medqa_train-1,pubmed_causal-1,medalpaca_cleaned-1,medqa_train&medmcqa_train-1,medmcqa_train-1,AlpaCare-1,ChatDoctor-1,ChatDoctor-2,AlpaCare&ChatDoctor-1,AlpaCare&ChatDoctor-2,medalpaca_cleaned&AlpaCare&ChatDoctor-1,medalpaca_cleaned&AlpaCare&ChatDoctor-2 +medmcqa,59.93,57.6,59.48,59.48,60.46,59.81,62.49,62.01,59.77,60.29,60.15,58.93,58.38,59.72,59.55 +medqa_4options,64.18,67.6,65.59,65.59,63.16,63.86,64.81,63.63,62.92,63.63,63.32,62.14,61.67,62.61,62.37 +anatomy,71.85,63.7,71.85,71.85,71.85,70.37,70.37,71.11,71.85,72.59,73.33,70.37,70.37,70.37,71.11 +clinical_knowledge,77.36,80.4,77.36,77.74,78.49,78.87,78.49,79.25,78.49,77.74,76.6,78.49,78.11,78.11,77.74 +college_biology,82.64,88.9,86.11,84.72,83.33,84.72,84.03,85.42,84.03,84.03,81.94,82.64,82.64,84.72,86.11 +college_medicine,69.36,76.3,68.79,69.94,69.36,69.94,68.79,68.21,68.79,67.05,67.63,69.36,68.79,68.79,71.1 +medical_genetics,87.0,75.0,87.0,88.0,87.0,85.0,89.0,89.0,87.0,86.0,88.0,86.0,87.0,85.0,83.0 +professional_medicine,78.68,83.8,76.84,79.78,76.47,77.94,78.68,76.47,77.21,77.57,77.21,75.74,76.1,77.21,76.84 +pubmedqa,75.2,79.0,76.0,75.8,76.4,75.8,76.8,75.8,75.0,74.8,73.8,74.8,75.0,76.2,75.6 +Avg,74.02,74.7,74.34,74.77,74.06,74.03,74.83,74.54,73.9,73.74,73.55,73.16,73.12,73.64,73.71 diff --git a/examples/dataset_llm_workflow/workflow.py b/examples/dataset_llm_workflow/workflow.py new file mode 100644 index 00000000..0b1bbfff --- /dev/null +++ b/examples/dataset_llm_workflow/workflow.py @@ -0,0 +1,399 @@ +import fire +import time +import tempfile +import os +import pandas as pd +import json +import re +import numpy as np +import matplotlib.pyplot as plt +import lm_eval +from lm_eval.models.huggingface import HFLM + +from learnware.client import LearnwareClient +from learnware.logger import get_module_logger +from learnware.market import BaseUserInfo, instantiate_learnware_market +from learnware.specification import GenerativeModelSpecification + +from benchmark import Benchmark +from eval_config import CONFIG + +logger = get_module_logger("llm_workflow", level="INFO") + + +class LLMWorkflow: + def _plot_radar_chart(self, benchmark_name, results_table): + labels = list(results_table.index) + if benchmark_name == "finance": + column_split = [ + ["Learnware", "Qwen2.5-7B", "Llama3.1-8B-Instruct", "Llama3.1-8B"], + ["Learnware", "Qwen1.5-110B", "Qwen2.5-72B", "Llama3.1-70B-Instruct"], + ["Learnware", "Random", "Best-single", "Oracle"], + ] + YTICKS = [0.2, 0.4, 0.6, 0.8, 1.0] + ylim = (0, 1.15) + x_label_fontsize = 4.5 + labels = [ + "Australian", + "LendingClub", + "FiQA-SA", + "FPB", + "German", + "Headlines", + "NER", + "ACL18", + "BigData22", + "CIKM18", + "SC", + "FinArg-ARC", + "FinArg-ACC", + "FOMC", + "MA", + "MLESG", + "MultiFin", + ] + elif benchmark_name == "math": + column_split = [ + ["Learnware", "Qwen2.5-7B"], + ["Learnware", "Qwen1.5-110B"], + ["Learnware", "Random", "Best-single", "Oracle"], + ] + YTICKS = [0.4, 0.6, 0.8, 1.0] + ylim = (0.3, 1.3) + x_label_fontsize = 5 + elif 
benchmark_name == "medical": + column_split = [ + ["Learnware", "Qwen2.5-7B"], + ["Learnware", "Flan-PaLM-540B"], + ["Learnware", "Random", "Best-single", "Oracle"], + ] + YTICKS = [0.8, 0.9, 1.0] + ylim = (0.75, 1.1) + x_label_fontsize = 8 + + num_vars = len(labels) + + angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() + angles += angles[:1] + + fig, axes = plt.subplots(1, 3, figsize=(16, 5), subplot_kw=dict(polar=True)) + + model_names = ["Learnware vs Base Model", "Learnware vs Large-scale Model", "Specialized SLMs"] + + colors = [ + np.array([0.9, 0.17, 0.31]), + np.array([1.0, 0.49, 0.0]), + np.array([0.19, 0.55, 0.91]), + np.array([0.56, 0.74, 0.56]), + np.array([0.66, 0.66, 0.66]), + ] + + for i, (ax, model_name) in enumerate(zip(axes, model_names)): + ax.set_xticks(angles[:-1]) + ax.set_yticks(YTICKS) + ax.set_xticklabels(labels, fontsize=x_label_fontsize, rotation=30) + ax.set_yticklabels([str(y) for y in YTICKS]) + ax.set_ylim(ylim[0], ylim[1]) + ax.set_title(model_name, pad=30) + + methods = column_split[i] + + for i, (method, color) in enumerate(zip(methods, colors[: len(methods)])): + if i == 0: + zorder = 2 + else: + zorder = 1 + + values = (results_table[method] / results_table["Oracle"]).tolist() + values += values[:1] + + ax.plot(angles, values, color=color, linewidth=2, label=method, zorder=zorder) + ax.fill(angles, values, color=color, alpha=0.1, zorder=zorder) + + ax.legend(loc="lower left", fontsize=8, bbox_to_anchor=(0.85, 0.9)) + + plt.tight_layout() + os.makedirs("results/figs", exist_ok=True) + plt.savefig(f"results/figs/llm-{benchmark_name}.pdf") + + def _anlysis_table(self, benchmark_name, table, score_results): + if benchmark_name == "finance": + start_column_id = 7 + else: # math / medical + start_column_id = 3 + table = table[:-1] + performance = table.melt( + id_vars=["Dataset"], value_vars=table.columns[start_column_id:], var_name="Source_Config" + ) + performance_extra = table.iloc[:, :start_column_id] + performance = pd.concat( + [performance, performance["Source_Config"].str.extract(r"(.+)-(\d+)").rename(columns={0: "Learnware"})], + axis=1, + ) + performance["Learnware"] = performance["Learnware"].apply(lambda s: s[:-1] if s[-1] == "-" else s) + performance = performance.rename(columns={"Dataset": "User"}) + performance.drop(columns=[1], inplace=True) + perf_merged = performance[["User", "Learnware", "value"]].groupby(["Learnware", "User"]).mean().reset_index() + + performance_extra = performance_extra.rename(columns={"Dataset": "User"}) + performance_extra = performance_extra.set_index("User") + + score_results = pd.DataFrame(score_results) + score_results["Rank-PAVE"] = ( + score_results.groupby("User")["Similarity"].rank(method="min", ascending=False).astype(int) - 1 + ) + adaptation_info = pd.merge(score_results, perf_merged, on=["Learnware", "User"]) + random_value = (adaptation_info[["User", "value"]].groupby(["User"]).mean()).rename(columns={"value": "Random"}) + oracle_value = (adaptation_info[["User", "value"]].groupby(["User"]).max()).rename(columns={"value": "Oracle"}) + pave_value = ( + adaptation_info[adaptation_info["Rank-PAVE"] < 1][["User", "value"]].groupby(["User"]).mean() + ).rename(columns={"value": "Learnware"}) + + # Best-single + perf_pivot = perf_merged.pivot(index="User", columns="Learnware", values="value") + best_column = perf_pivot.mean().idxmax() + best_single = perf_pivot[[best_column]].rename(columns={best_column: "Best-single"}) + + adaptation_table = pd.concat([random_value, pave_value, best_single, 
oracle_value], axis=1) + + # join performance_extra + adaptation_table = performance_extra.join(adaptation_table) + + # Avg Rank + ranks = adaptation_table.rank(axis=1, method="min", ascending=False) + avg_rank = ranks.mean() + + # PAVE win/tie/loss + pave_scores = adaptation_table["Learnware"] + win_tie_loss = {} + + for col in adaptation_table.columns: + if col == "Learnware": + continue + win = (pave_scores > adaptation_table[col]).sum() + tie = (pave_scores == adaptation_table[col]).sum() + loss = (pave_scores < adaptation_table[col]).sum() + win_tie_loss[col] = f"{win}/{tie}/{loss}" + + # Oracle win/tie/loss + oracle_scores = adaptation_table["Oracle"] + win_tie_loss_o = {} + + for col in adaptation_table.columns: + if col == "Oracle": + continue + win = (oracle_scores > adaptation_table[col]).sum() + tie = (oracle_scores == adaptation_table[col]).sum() + loss = (oracle_scores < adaptation_table[col]).sum() + win_tie_loss_o[col] = f"{win}/{tie}/{loss}" + + adaptation_table.loc["Avg."] = adaptation_table.mean() + adaptation_table.loc["Avg. rank"] = avg_rank + adaptation_table = adaptation_table.round(2) + adaptation_table.loc["Learnware (win/tie/loss)"] = win_tie_loss + adaptation_table.loc["Oracle (win/tie/loss)"] = win_tie_loss_o + + print(adaptation_table.to_markdown()) + os.makedirs("results/tables", exist_ok=True) + adaptation_table.to_csv(f"results/tables/llm-{benchmark_name}.csv") + + return adaptation_table + + def _prepare_market(self, benchmark: Benchmark, rebuild=False): + client = LearnwareClient() + self.llm_benchmark = benchmark + self.llm_market = instantiate_learnware_market( + market_id=f"llm_{self.llm_benchmark.name}", name="llm", rebuild=rebuild + ) + self.user_semantic = client.get_semantic_specification(self.llm_benchmark.learnware_ids[0]) + self.user_semantic["Name"]["Values"] = "" + self.user_semantic["Description"]["Values"] = "" + self.user_semantic["License"]["Values"] = ["Apache-2.0", "Others"] + + if len(self.llm_market) == 0 or rebuild is True: + for learnware_id in self.llm_benchmark.learnware_ids: + with tempfile.TemporaryDirectory(prefix="llm_benchmark_") as tempdir: + zip_path = os.path.join(tempdir, f"{learnware_id}.zip") + for i in range(20): + try: + semantic_spec = client.get_semantic_specification(learnware_id) + client.download_learnware(learnware_id, zip_path) + self.llm_market.add_learnware(zip_path, semantic_spec) + break + except Exception: + time.sleep(1) + continue + + logger.info("Total Item: %d" % (len(self.llm_market))) + + def build_specification_and_cache(self, name, saved_folder, benchmark: Benchmark): + generative_spec = GenerativeModelSpecification() + generative_spec_path = os.path.join(saved_folder, name, "generative.pth") + + os.makedirs(os.path.join(saved_folder, name), exist_ok=True) + + if os.path.exists(generative_spec_path): + generative_spec.load(generative_spec_path) + else: + train_dataset = benchmark.get_user_dataset(name) + generative_spec.generate_stat_spec_from_data(dataset=train_dataset) + generative_spec.save(generative_spec_path) + + return generative_spec + + def _get_scores(self, benchmark_name, base_model: str, adapter_path, batch_size="auto"): + benchmark_configs = CONFIG[benchmark_name] + task_manager = lm_eval.tasks.TaskManager() + task_names = [config.name for config in benchmark_configs] + + if benchmark_name == "medical": + lm_obj = HFLM(pretrained=base_model, peft=adapter_path, batch_size=batch_size) + results = lm_eval.simple_evaluate( + model=lm_obj, + tasks=task_names, + task_manager=task_manager, + ) + 
else: + results_dir = f"./eval_results/{benchmark_name}" + adapter_id = adapter_path.split("/")[-2] if adapter_path else None + task_names_str = ",".join(task_names) + if adapter_path: + os.system( + f"CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch -m lm_eval --model hf \ + --model_args pretrained={base_model},peft={adapter_path} \ + --tasks {task_names_str} \ + --batch_size {batch_size} \ + --output_path ./eval_results/{benchmark_name}" + ) + elif base_model in ["Qwen/Qwen1.5-110B", "Qwen/Qwen2.5-72B", "NousResearch/Meta-Llama-3.1-70B-Instruct"]: + os.system( + f"CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 1 -m lm_eval --model hf \ + --model_args pretrained={base_model},parallelize=True \ + --tasks {task_names_str} \ + --batch_size {batch_size} \ + --output_path ./eval_results/{benchmark_name}" + ) + else: + os.system( + f"CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch -m lm_eval --model hf \ + --model_args pretrained={base_model} \ + --tasks {task_names_str} \ + --batch_size {batch_size} \ + --output_path ./eval_results/{benchmark_name}" + ) + + if adapter_id: + for dir_name in os.listdir(results_dir): + if adapter_id in dir_name: + results_dir_path = os.path.join(results_dir, dir_name) + results_path = os.path.join(results_dir_path, sorted(os.listdir(results_dir_path))[-1]) + break + else: + for dir_name in os.listdir(results_dir): + if dir_name == base_model.replace("/", "__"): + results_dir_path = os.path.join(results_dir, dir_name) + results_path = os.path.join(results_dir_path, sorted(os.listdir(results_dir_path))[-1]) + break + + with open(results_path, "r", encoding="utf-8") as f: + results = json.load(f) + + score_list = [] + for config in benchmark_configs: + score = results["results"][config.name][f"{config.eval_metric},none"] * 100 + score = round(score, 2) + logger.info(f"Name: {config.name}, Score: {score}") + score_list.append(score) + + return score_list + + def llm_example(self, benchmark_name, rebuild=False, skip_eval=True): + benchmark = Benchmark(benchmark_name) + self._prepare_market(benchmark, rebuild) + user_names = benchmark.get_user_names() + + score_results = {"User": [], "Learnware": [], "Similarity": []} + + for name in user_names: + title = "=" * 20 + name + "=" * 20 + print(title) + + generative_spec = self.build_specification_and_cache(name, "user_specs", benchmark) + + user_info = BaseUserInfo( + semantic_spec=self.user_semantic, stat_info={"GenerativeModelSpecification": generative_spec} + ) + logger.info(f"Searching Market for user: {name}") + + search_result = self.llm_market.search_learnware(user_info) + single_result = search_result.get_single_results() + + scores = {} + for result in single_result: + learnware_name = result.learnware.specification.semantic_spec["Name"]["Values"] + match = re.match(r"(.+)-(\d+)", learnware_name) + dataset_name = match.group(1) + scores[dataset_name] = result.score + + for k, v in scores.items(): + score_results["User"].append(name) + score_results["Learnware"].append(k) + score_results["Similarity"].append(v) + + if not skip_eval: + all_learnwares_ids = self.llm_market.get_learnware_ids() + if benchmark_name == "medical": + performance_table = { + "Qwen2.5-7B": self._get_scores(benchmark_name, "Qwen/Qwen2.5-7B", None), + "Flan-PaLM-540B": [ + 57.60, + 67.60, + 63.70, + 80.40, + 88.90, + 76.30, + 75.00, + 83.80, + 79.00, + ], # copied from Open Medical LLM Leaderboard + } + elif benchmark_name == "math": + performance_table = { + "Qwen2.5-7B": self._get_scores(benchmark_name, "Qwen/Qwen2.5-7B", None), 
+ "Qwen1.5-110B": self._get_scores(benchmark_name, "Qwen/Qwen1.5-110B", None), + } + elif benchmark_name == "finance": + performance_table = { + "Qwen2.5-7B": self._get_scores(benchmark_name, "Qwen/Qwen2.5-7B", None), + "Llama3.1-8B-Instruct": self._get_scores( + benchmark_name, "NousResearch/Meta-Llama-3.1-8B-Instruct", None + ), + "Llama3.1-8B": self._get_scores(benchmark_name, "NousResearch/Meta-Llama-3.1-8B", None), + "Qwen1.5-110B": self._get_scores(benchmark_name, "Qwen/Qwen1.5-110B", None), + "Qwen2.5-72B": self._get_scores(benchmark_name, "Qwen/Qwen2.5-72B", None), + "Llama3.1-70B-Instruct": self._get_scores( + benchmark_name, "NousResearch/Meta-Llama-3.1-70B-Instruct", None + ), + } + + for learnware_id in all_learnwares_ids: + learnware = self.llm_market.get_learnware_by_ids(learnware_id) + base_model = learnware.specification.semantic_spec["Description"]["Values"].split(" ")[-1] + adapter_path = os.path.join(self.llm_market.get_learnware_dir_path_by_ids(learnware_id), "adapter") + score_list = self._get_scores(benchmark_name, base_model, adapter_path) + performance_table[learnware.specification.semantic_spec["Name"]["Values"]] = score_list + + performance_table = pd.DataFrame(performance_table) + performance_table = performance_table._append(performance_table.mean().round(2), ignore_index=True) + datasets = benchmark.get_user_names() + performance_table.insert(0, "Dataset", datasets + ["Avg"]) + performance_table.to_csv(f"model_performance/{benchmark_name}-new.csv", index=False) + else: + performance_table = pd.read_csv(f"model_performance/{benchmark_name}.csv") + + results_table = self._anlysis_table(benchmark_name, performance_table, score_results) + self._plot_radar_chart(benchmark_name, results_table[:-4]) + + +if __name__ == "__main__": + fire.Fire(LLMWorkflow) diff --git a/examples/dataset_table_workflow/base.py b/examples/dataset_table_workflow/base.py index 6f6559cd..80848dfc 100644 --- a/examples/dataset_table_workflow/base.py +++ b/examples/dataset_table_workflow/base.py @@ -14,7 +14,7 @@ from learnware.logger import get_module_logger from learnware.market import instantiate_learnware_market from learnware.reuse.utils import fill_data_with_mean -from learnware.tests.benchmarks import LearnwareBenchmark +from learnware.tests.benchmarks import LearnwareBenchmarkManager logger = get_module_logger("base_table", level="INFO") @@ -63,18 +63,20 @@ def get_train_subsets(n_labeled_list, n_repeat_list, train_x, train_y): def _prepare_market(self, benchmark_config, name, rebuild, retrain): client = LearnwareClient() - self.benchmark = LearnwareBenchmark().get_benchmark(benchmark_config) + self.benchmark = LearnwareBenchmarkManager().get_benchmark(benchmark_config) self.market = instantiate_learnware_market( market_id=self.benchmark.name, name=name, rebuild=rebuild, - organizer_kwargs={ - "auto_update": True, - "auto_update_limit": len(self.benchmark.learnware_ids), - **market_mapping_params, - } - if retrain - else None, + organizer_kwargs=( + { + "auto_update": True, + "auto_update_limit": len(self.benchmark.learnware_ids), + **market_mapping_params, + } + if retrain + else None + ), ) self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0]) self.user_semantic["Name"]["Values"] = "" diff --git a/examples/dataset_text_workflow/workflow.py b/examples/dataset_text_workflow/workflow.py index 42bc315f..04e9ef70 100644 --- a/examples/dataset_text_workflow/workflow.py +++ b/examples/dataset_text_workflow/workflow.py @@ -17,7 +17,7 @@ from learnware.market 
import BaseUserInfo, instantiate_learnware_market from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser from learnware.specification import RKMETextSpecification -from learnware.tests.benchmarks import LearnwareBenchmark +from learnware.tests.benchmarks import LearnwareBenchmarkManager logger = get_module_logger("text_workflow", level="INFO") @@ -72,7 +72,7 @@ def _plot_labeled_peformance_curves(self, all_user_curves_data): def _prepare_market(self, rebuild=False): client = LearnwareClient() - self.text_benchmark = LearnwareBenchmark().get_benchmark(text_benchmark_config) + self.text_benchmark = LearnwareBenchmarkManager().get_benchmark(text_benchmark_config) self.text_market = instantiate_learnware_market(market_id=self.text_benchmark.name, rebuild=rebuild) self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0]) self.user_semantic["Name"]["Values"] = "" diff --git a/learnware/__init__.py b/learnware/__init__.py index 97e81afd..85ae6577 100644 --- a/learnware/__init__.py +++ b/learnware/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.2.99" +__version__ = "0.4.0.post1" import json import os diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 2be5e550..4249917e 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -52,24 +52,50 @@ class SemanticSpecificationKey(Enum): DATA_TYPE = "Data" TASK_TYPE = "Task" LIBRARY_TYPE = "Library" + MODEL_TYPE = "Model" SENARIOES = "Scenario" LICENSE = "License" class LearnwareClient: - def __init__(self, host=None): + def __init__(self, host=None, timeout=None): self.headers = None if host is None: - self.host = C.backend_host + host = os.environ.get("LEARNWARE_BACKEND_HOST") + if host is None: + self.host = C.backend_host + pass + else: + self.host = host + pass + pass else: self.host = host + pass self.chunk_size = 1024 * 1024 self.tempdir_list = [] self.login_status = False + if timeout is None: + self.timeout = 60 + else: + self.timeout = timeout atexit.register(self.cleanup) + self.storage_path = os.environ.get("LEARNWARE_STORAGE_PATH") + if self.storage_path is None: + self.storage_path = os.path.join(os.path.expanduser("~"), ".learnware", "default", "learnware_pool") + pass + self.default_zip_path = os.path.join(self.storage_path, "zips") + self.default_unzip_path = os.path.join(self.storage_path, "unzipped_learnwares") + if not os.path.exists(self.default_zip_path): + os.makedirs(self.default_zip_path, exist_ok=True) + pass + if not os.path.exists(self.default_unzip_path): + os.makedirs(self.default_unzip_path, exist_ok=True) + pass + def is_connected(self): url = f"{self.host}/auth/login_by_token" response = requests.post(url) @@ -80,8 +106,7 @@ def is_connected(self): def login(self, email, token): url = f"{self.host}/auth/login_by_token" - response = requests.post(url, json={"email": email, "token": token}) - + response = requests.post(url, json={"email": email, "token": token}, timeout=self.timeout) result = response.json() if result["code"] != 0: raise Exception("login failed: " + json.dumps(result)) @@ -189,7 +214,11 @@ def get_semantic_specification(self, learnware_id: str): return result["data"]["learnware_info"]["semantic_specification"] - def download_learnware(self, learnware_id: str, save_path: str): + def download_learnware(self, learnware_id: str, save_path: str = None): + if save_path is None: + save_path = os.path.join(self.default_zip_path, learnware_id + ".zip") + pass + url = 
f"{self.host}/engine/download_learnware" response = requests.get( @@ -202,7 +231,7 @@ def download_learnware(self, learnware_id: str, save_path: str): ) if response.status_code != 200: - raise Exception("download failed: " + json.dumps(response.json())) + raise Exception("download failed: " + response.text) num_chunks = int(response.headers["Content-Length"]) // CHUNK_SIZE + 1 bar = tqdm(total=num_chunks, desc="Downloading", unit="MB") @@ -272,6 +301,7 @@ def search_learnware(self, user_info: BaseUserInfo, page_size=10, page_index=0): "page": page_index, }, headers=self.headers, + timeout=self.timeout, ) result = response.json() if result["code"] != 0: @@ -309,6 +339,43 @@ def list_semantic_specification_values(self, key: SemanticSpecificationKey): semantic_conf = result["data"]["semantic_specification"] return semantic_conf[key.value]["Values"] + def get_pretrained_path(self, learnware_id: str): + # get pretrained path from learnware id + + # check learnware exists + if os.path.exists(os.path.join(self.default_unzip_path, learnware_id)): + pass + else: + # learnware not exist + if not os.path.exists(os.path.join(self.default_zip_path, learnware_id + ".zip")): + self.download_learnware(learnware_id) + pass + else: + # learnware exists + pass + self.unzip_learnware(learnware_id) + pass + + yaml_file = os.path.join(self.default_unzip_path, learnware_id, C.learnware_folder_config["yaml_file"]) + with open(yaml_file, "r") as fin: + learnware_info = yaml.safe_load(fin) + pass + pretrained_path = learnware_info["model"].get("weights_file_path") + if pretrained_path is None: + raise FileNotFoundError(f"Pretrained path not found in learnware {learnware_id}") + + return os.path.join(self.default_unzip_path, learnware_id, pretrained_path) + pass + + def unzip_learnware(self, learnware_id: str): + if not os.path.exists(os.path.join(self.default_zip_path, learnware_id + ".zip")): + raise FileNotFoundError(f"Learnware {learnware_id} not found") + else: + with zipfile.ZipFile(os.path.join(self.default_zip_path, learnware_id + ".zip"), "r") as z_file: + z_file.extractall(os.path.join(self.default_unzip_path, learnware_id)) + pass + pass + def load_learnware( self, learnware_path: Optional[Union[str, List[str]]] = None, @@ -425,7 +492,7 @@ def check_learnware(learnware_zip_path, semantic_specification=None): name="test", description="test", data_type="Text", - task_type="Segmentation", + task_type="Text Generation", scenarios="Financial", library_type="Scikit-learn", license="Apache-2.0", @@ -440,7 +507,7 @@ def check_learnware(learnware_zip_path, semantic_specification=None): with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: z_file.extractall(tempdir) - + pass learnware = get_learnware_from_dirpath( id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir, ignore_error=False ) diff --git a/learnware/client/utils.py b/learnware/client/utils.py index fc96c01d..0bbeefc2 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -8,7 +8,7 @@ logger = get_module_logger(module_name="client_utils") -def system_execute(args, timeout=None, env=None, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE): +def system_execute(args, timeout=None, env=None, stdout=None, stderr=None): env = os.environ.copy() if env is None else env args = args if isinstance(args, str) else " ".join(args) @@ -92,6 +92,8 @@ def install_environment(learnware_dirpath, conda_env, conda_prefix=None): raise Exception("Environment.yaml 
or requirements.txt not found in the learnware folder.") logger.info(f"install learnware package for conda env [{conda_env}]") + learnware_package = os.environ.get("LEARNWARE_PACKAGE_LOCATION", "learnware") + system_execute( args=[ "conda", @@ -104,6 +106,6 @@ def install_environment(learnware_dirpath, conda_env, conda_prefix=None): "-m", "pip", "install", - "learnware", + learnware_package, ] ) diff --git a/learnware/config.py b/learnware/config.py index a94e3b6f..4e3e9601 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -92,10 +92,20 @@ def get_platform(): "Feature Extraction", "Segmentation", "Object Detection", + "Text Generation", "Others", ], "Type": "Class", # Choose only one class }, + "Model": { + "Values": [ + "Base Model", + "Fine-tuned Model", + "PEFT Model", + "Others", + ], + "Type": "Optional", + }, "Library": { "Values": ["Scikit-learn", "PyTorch", "TensorFlow", "Others"], "Type": "Class", @@ -162,6 +172,7 @@ def get_platform(): "learnware_folder_config": { "yaml_file": "learnware.yaml", "module_file": "__init__.py", + "weights_file_path": "weights", }, "database_url": f"sqlite:///{DATABASE_PATH}", "max_reduced_set_size": 1310720, diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index 60996a75..78942176 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -35,16 +35,11 @@ def get_learnware_from_dirpath( learnware_config = { "model": { "class_name": "Model", + "weights_file_path": "weights", + "required_learnware_ids": [], "kwargs": {}, }, - "stat_specifications": [ - { - "module_path": "learnware.specification", - "class_name": "RKMETableSpecification", - "file_name": "stat_spec.json", - "kwargs": {}, - }, - ], + "stat_specifications": [], } try: @@ -65,6 +60,23 @@ def get_learnware_from_dirpath( if "module_path" not in learnware_config["model"]: learnware_config["model"]["module_path"] = C.learnware_folder_config["module_file"] + if semantic_spec["Data"]["Values"] == ["Text"] and semantic_spec["Task"]["Values"] == ["Text Generation"]: + if "weights_file_path" not in learnware_config["model"]: + learnware_config["model"]["weights_file_path"] = C.learnware_folder_config["weights_file_path"] + + learnware_weights_path = os.path.join(learnware_dirpath, learnware_config["model"]["weights_file_path"]) + assert os.path.exists( + learnware_weights_path + ), f"Weights are not found for the Text Generation Model learnware_{id}, please check the learnware.yaml or zipfile." + + if semantic_spec["Model"]["Values"] == ["PEFT Model"]: + assert ( + "required_learnware_ids" in learnware_config["model"] + ), f"'required_learnware_ids' is not found for the PEFT Model learnware_{id}, please check the learnware.yaml." + assert ( + len(learnware_config["model"]["required_learnware_ids"]) != 0 + ), f"'required_learnware_ids' can't be empty for the PEFT Model learnware_{id}, please check the learnware.yaml." 
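+        # Illustrative sketch (an assumption, not prescribed by this code): for a PEFT
+        # text-generation learnware, a learnware.yaml "model" section that passes the
+        # checks above could look like
+        #
+        #   model:
+        #     class_name: Model
+        #     module_path: __init__.py
+        #     weights_file_path: weights              # adapter weights shipped inside the zip
+        #     required_learnware_ids: ["<base_model_learnware_id>"]   # hypothetical id
+        #     kwargs: {}
+        #   stat_specifications: []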
+ learnware_spec = Specification() for _stat_spec in learnware_config["stat_specifications"]: stat_spec = _stat_spec.copy() diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index 6e9e718e..760884ca 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -1,10 +1,19 @@ from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket from .classes import CondaChecker -from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker +from .easy import ( + EasyExactSemanticSearcher, + EasyFuzzSemanticSearcher, + EasyOrganizer, + EasySemanticChecker, + EasyStatChecker, + EasyStatSearcher, + SeqCombinedSearcher, +) from .evolve import EvolvedOrganizer from .evolve_anchor import EvolvedAnchoredOrganizer -from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher +from .heterogeneous import HeteroMapTableOrganizer, HeteroStatSearcher +from .llm import LLMEasyOrganizer, LLMStatSearcher from .module import instantiate_learnware_market __all__ = [ @@ -18,12 +27,17 @@ "LearnwareMarket", "CondaChecker", "EasyOrganizer", - "EasySearcher", + "EasyExactSemanticSearcher", + "EasyFuzzSemanticSearcher", + "EasyStatSearcher", + "SeqCombinedSearcher", "EasySemanticChecker", "EasyStatChecker", "EvolvedOrganizer", "EvolvedAnchoredOrganizer", "HeteroMapTableOrganizer", - "HeteroSearcher", + "HeteroStatSearcher", + "LLMEasyOrganizer", + "LLMStatSearcher", "instantiate_learnware_market", ] diff --git a/learnware/market/anchor/searcher.py b/learnware/market/anchor/searcher.py index c8fe2a26..5d6eef22 100644 --- a/learnware/market/anchor/searcher.py +++ b/learnware/market/anchor/searcher.py @@ -1,14 +1,14 @@ from typing import Any, List, Tuple from .user_info import AnchoredUserInfo -from ..easy.searcher import EasySearcher +from ..base import AtomicSearcher from ...learnware import Learnware from ...logger import get_module_logger logger = get_module_logger("anchor_searcher") -class AnchoredSearcher(EasySearcher): +class AnchoredSearcher(AtomicSearcher): def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: """Search anchor Learnwares from anchor_learnware_list based on user_info diff --git a/learnware/market/base.py b/learnware/market/base.py index 3b53798c..2b87f564 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -203,6 +203,8 @@ def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, ** SearchResults Search results """ + # searcher = self.searcher_selector.select_searcher(user_info) + # return searcher(user_info, check_status, **kwargs) return self.learnware_searcher(user_info, check_status, **kwargs) def delete_learnware(self, id: str, **kwargs) -> bool: @@ -501,6 +503,54 @@ def __call__(self, user_info: BaseUserInfo, check_status: int = None) -> SearchR raise NotImplementedError("'__call__' method is not implemented in BaseSearcher") +class AtomicSearcher(BaseSearcher): + def __init__(self, organizer: BaseOrganizer, **kwargs): + super(AtomicSearcher, self).__init__(organizer, **kwargs) + + def is_applicable_user(self, user_info: BaseUserInfo, **kwargs) -> bool: + """Check if the user_info is applicable for this searcher + + Parameters + ---------- + user_info : BaseUserInfo + user_info contains semantic_spec and stat_info + + Returns + ------- + bool + A flag indicating whether the user_info is applicable for this searcher + """ + raise 
NotImplementedError("'is_applicable_user' method is not implemented in AtomicSearcher") + + def is_applicable_learnware(self, learnware: Learnware, **kwargs) -> bool: + """Check if the learnware is applicable for this searcher + + Parameters + ---------- + learnware : Learnware + learnware to be checked + + Returns + ------- + bool + A flag indicating whether the learnware is applicable for this searcher + """ + raise NotImplementedError("'is_applicable_learnware' method is not implemented in AtomicSearcher") + + def __call__(self, user_info: BaseUserInfo, check_status: int = None) -> SearchResults: + """Search learnwares based on user_info from learnwares with check_status + + Parameters + ---------- + user_info : BaseUserInfo + user_info contains semantic_spec and stat_info + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status + """ + raise NotImplementedError("'__call__' method is not implemented in AtomicSearcher") + + class BaseChecker: INVALID_LEARNWARE = -1 NONUSABLE_LEARNWARE = 0 diff --git a/learnware/market/easy/__init__.py b/learnware/market/easy/__init__.py index bbedeefb..1ec3058f 100644 --- a/learnware/market/easy/__init__.py +++ b/learnware/market/easy/__init__.py @@ -5,16 +5,16 @@ logger = get_module_logger("market_easy") if not is_torch_available(verbose=False): - EasySearcher = None EasySemanticChecker = None EasyStatChecker = None EasyExactSemanticSearcher = None EasyFuzzSemanticSearcher = None EasyStatSearcher = None + SeqCombinedSearcher = None logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") else: from .checker import EasySemanticChecker, EasyStatChecker - from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher + from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasyStatSearcher, SeqCombinedSearcher __all__ = [ "EasyOrganizer", @@ -22,6 +22,6 @@ "EasyStatChecker", "EasyExactSemanticSearcher", "EasyFuzzSemanticSearcher", - "EasySearcher", "EasyStatSearcher", + "SeqCombinedSearcher", ] diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 3b22c9bb..568ce346 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -9,6 +9,8 @@ from ..utils import parse_specification_type from ...config import C from ...logger import get_module_logger +from ...specification import LLMGeneralCapabilitySpecification +from ...specification.system.llm_general_capability_spec.config import test_benchmark_configs logger = get_module_logger("easy_checker", "INFO") @@ -18,11 +20,16 @@ class EasySemanticChecker(BaseChecker): def check_semantic_spec(semantic_spec): try: for key in C["semantic_specs"]: + if C["semantic_specs"][key]["Type"] == "Optional": + if key not in semantic_spec: + continue + pass + value = semantic_spec[key]["Values"] valid_type = C["semantic_specs"][key]["Type"] assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch" - if valid_type == "Class": + if valid_type == "Class" or valid_type == "Optional": valid_list = C["semantic_specs"][key]["Values"] assert len(value) == 1, f"{key} must be unique" assert value[0] in valid_list, f"{key} must be in {valid_list}" @@ -44,6 +51,36 @@ def check_semantic_spec(semantic_spec): assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" assert isinstance(v, str), "Description must be string" + assert semantic_spec["Task"]["Values"][0] in [ + "Classification", + 
"Regression", + "Feature Extraction", + "Others", + ] + + assert semantic_spec["Model"]["Values"][0] == "Others" + + if semantic_spec["Data"]["Values"][0] == "Image": + assert semantic_spec["Task"]["Values"][0] in [ + "Classification", + "Regression", + "Feature Extraction", + "Segmentation", + "Object Detection", + "Others", + ] + + assert semantic_spec["Model"]["Values"][0] == "Others" + + if semantic_spec["Data"]["Values"][0] == "Text": + assert semantic_spec["Task"]["Values"][0] in [ + "Classification", + "Regression", + "Feature Extraction", + "Text Generation", + "Others", + ] + if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: assert semantic_spec["Output"] is not None, "Lack of output semantics" dim = semantic_spec["Output"]["Dimension"] @@ -106,6 +143,26 @@ def __call__(self, learnware): logger.warning(message) return self.INVALID_LEARNWARE, message + # check llm base model learnware general capability + if ( + semantic_spec["Data"]["Values"] == ["Text"] + and semantic_spec["Task"]["Values"] == ["Text Generation"] + and semantic_spec["Model"]["Values"] == ["Base Model"] + ): + try: + general_capability_spec = LLMGeneralCapabilitySpecification() + general_capability_spec.generate_stat_spec_from_system( + learnware=learnware, benchmark_configs=test_benchmark_configs + ) + learnware.update_stat_spec(general_capability_spec.type, general_capability_spec) + except Exception: + message = ( + f"The learnware [{learnware.id}] llm base model general capability evaluation is not available!" + ) + logger.warning(message) + message += "\r\n" + traceback.format_exc() + return self.INVALID_LEARNWARE, message + # Check statistical specification spec_type = parse_specification_type(learnware.get_specification().stat_spec) if spec_type is None: @@ -114,12 +171,15 @@ def __call__(self, learnware): return self.INVALID_LEARNWARE, message # Check if statistical specification is computable in dist() - stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) - distance = float(stat_spec.dist(stat_spec)) - if not np.isfinite(distance): - message = f"The distance between statistical specifications is not finite, where distance={distance}" - logger.warning(message) - return self.INVALID_LEARNWARE, message + if spec_type != "LLMGeneralCapabilitySpecification": + stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) + distance = float(stat_spec.dist(stat_spec)) + if not np.isfinite(distance): + message = ( + f"The distance between statistical specifications is not finite, where distance={distance}" + ) + logger.warning(message) + return self.INVALID_LEARNWARE, message if spec_type == "RKMETableSpecification": if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): @@ -133,8 +193,16 @@ def __call__(self, learnware): return self.INVALID_LEARNWARE, message inputs = np.random.randn(10, *input_shape) - elif spec_type == "RKMETextSpecification": - inputs = EasyStatChecker._generate_random_text_list(10) + elif spec_type in [ + "RKMETextSpecification", + "GenerativeModelSpecification", + "LLMGeneralCapabilitySpecification", + ]: + if semantic_spec["Model"]["Values"][0] != "Others": + len_ = random.randint(10, 1000) + inputs = EasyStatChecker._generate_random_text_list(10, "en", len_, len_) + else: + inputs = EasyStatChecker._generate_random_text_list(10) elif spec_type == "RKMEImageSpecification": if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): @@ -150,14 +218,14 
@@ def __call__(self, learnware): try: outputs = learnware.predict(inputs) except Exception: - message = f"The learnware {learnware.id} prediction is not avaliable!" + message = f"The learnware [{learnware.id}] prediction is not available!" logger.warning(message) message += "\r\n" + traceback.format_exc() return self.INVALID_LEARNWARE, message # Check length of input and output if len(inputs) != len(outputs): - message = f"The learnware {learnware.id} output length must be equal to input length!" + message = f"The learnware [{learnware.id}] output length must be equal to input length!" logger.warning(message) return self.INVALID_LEARNWARE, message @@ -170,7 +238,7 @@ def __call__(self, learnware): if isinstance(outputs, torch.Tensor): outputs = outputs.detach().cpu().numpy() if not isinstance(outputs, np.ndarray): - message = f"The learnware {learnware.id} output must be np.ndarray or torch.Tensor!" + message = f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!" logger.warning(message) return self.INVALID_LEARNWARE, message diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index dcfb7dfa..9847cb1c 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -5,8 +5,15 @@ import torch from rapidfuzz import fuzz -from .organizer import EasyOrganizer -from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem +from ..base import ( + AtomicSearcher, + BaseOrganizer, + BaseSearcher, + BaseUserInfo, + MultipleSearchItem, + SearchResults, + SingleSearchItem, +) from ..utils import parse_specification_type from ...learnware import Learnware from ...logger import get_module_logger @@ -15,7 +22,13 @@ logger = get_module_logger("easy_seacher") -class EasyExactSemanticSearcher(BaseSearcher): +class EasyExactSemanticSearcher(AtomicSearcher): + def is_applicable_learnware(self, learnware: Learnware) -> bool: + return True + + def is_applicable_user(self, user_info: BaseUserInfo) -> bool: + return True + def _learnware_id_search(self, learnware_id: str, learnware_list: List[Learnware]) -> List[Learnware]: match_learnwares = [] for learnware in learnware_list: @@ -78,7 +91,13 @@ def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares]) -class EasyFuzzSemanticSearcher(BaseSearcher): +class EasyFuzzSemanticSearcher(AtomicSearcher): + def is_applicable_learnware(self, learnware: Learnware) -> bool: + return True + + def is_applicable_user(self, user_info: BaseUserInfo) -> bool: + return True + def _learnware_id_search(self, learnware_id: str, learnware_list: List[Learnware]) -> List[Learnware]: match_learnwares = [] for learnware in learnware_list: @@ -103,16 +122,24 @@ def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool: """ for key in semantic_spec1.keys(): v1 = semantic_spec1[key].get("Values", "") - if key not in semantic_spec2 or len(v1) == 0: + if len(v1) == 0: continue + if key not in semantic_spec2: + if "Others" in v1: + # v1 contains "Others" and key not in semantic_spec2 + continue + else: + # user input contains some key that is not in database + return False + v2 = semantic_spec2[key].get("Values", "") if key not in ("Name", "Description"): if len(v2) == 0: # user input contains some key that is not in database return False - if semantic_spec1[key]["Type"] == "Class": + if semantic_spec1[key]["Type"] in ("Class", 
"Optional"): if isinstance(v2, list): v2 = v2[0] if v2 not in v1: @@ -203,7 +230,22 @@ def __call__( return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result]) -class EasyStatSearcher(BaseSearcher): +class EasyStatSearcher(AtomicSearcher): + SPEC_TYPES = ["RKMETableSpecification", "RKMEImageSpecification", "RKMETextSpecification"] + + def is_applicable_learnware(self, learnware: Learnware) -> bool: + return any(spec_type in learnware.specification.stat_spec for spec_type in self.SPEC_TYPES) + + def is_applicable_user(self, user_info: BaseUserInfo) -> bool: + for spec_type in self.SPEC_TYPES: + if spec_type in user_info.stat_info: + user_rkme = user_info.stat_info[spec_type] + + if np.isfinite(float(user_rkme.dist(user_rkme))): + return True + + return False + def _convert_dist_to_score( self, dist_list: List[float], dist_ratio: float = 0.1, min_score: float = 0.92, improve_score: float = 0.7 ) -> List[float]: @@ -586,13 +628,9 @@ def __call__( max_search_num: int = 5, search_method: str = "greedy", ) -> SearchResults: - self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info) - if self.stat_spec_type is None: - raise KeyError("No supported stat specification is given in the user info") + self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info, spec_list=self.SPEC_TYPES) user_rkme = user_info.stat_info[self.stat_spec_type] - if not np.isfinite(float(user_rkme.dist(user_rkme))): - raise ValueError("The distance between uploaded statistical specifications is not finite!") learnware_list = self._filter_by_rkme_spec_metadata(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") @@ -664,48 +702,62 @@ def __call__( return search_results -class EasySearcher(BaseSearcher): - def __init__(self, organizer: EasyOrganizer): - self.semantic_searcher = EasyFuzzSemanticSearcher(organizer) - self.stat_searcher = EasyStatSearcher(organizer) - super(EasySearcher, self).__init__(organizer) - - def reset(self, organizer): +class SeqCombinedSearcher(BaseSearcher): + def __init__( + self, + organizer: BaseOrganizer, + semantic_searcher_list: List[AtomicSearcher], + stat_searcher_list: List[AtomicSearcher], + ): + self.semantic_searcher_list = semantic_searcher_list + self.stat_searcher_list = stat_searcher_list + super(SeqCombinedSearcher, self).__init__(organizer) + + def reset(self, organizer: BaseOrganizer): self.learnware_organizer = organizer - self.semantic_searcher.reset(organizer) - self.stat_searcher.reset(organizer) + for searcher in self.semantic_searcher_list + self.stat_searcher_list: + searcher.reset(organizer) def __call__( self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" ) -> SearchResults: - """Search learnwares based on user_info from learnwares with check_status + """ + Search learnwares based on user_info, iterating over semantic and stat searchers to find applicable results. Parameters ---------- user_info : BaseUserInfo - user_info contains semantic_spec and stat_info - max_search_num : int - The maximum number of the returned learnwares + The user information for searching learnwares. + max_search_num : int, optional + The maximum number of the returned learnwares. check_status : int, optional - None: search from all learnwares - - Others: search from learnwares with check_status + - Others: search from learnwares with check_status. 
Returns ------- - Tuple[List[float], List[Learnware], float, List[Learnware]] - the first is the sorted list of rkme dist - the second is the sorted list of Learnware (single) by the rkme dist - the third is the score of Learnware (mixture) - the fourth is the list of Learnware (mixture), the size is search_num + SearchResults + The search results, including sorted lists of learnwares and associated scores. """ learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) - semantic_search_result = self.semantic_searcher(learnware_list, user_info) - learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] + for semantic_searcher in self.semantic_searcher_list: + if semantic_searcher.is_applicable_user(user_info): + filtered_learnware_list = [ + learnware for learnware in learnware_list if semantic_searcher.is_applicable_learnware(learnware) + ] + semantic_search_result = semantic_searcher(filtered_learnware_list, user_info) + learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] + break + if len(learnware_list) == 0: return SearchResults() - if parse_specification_type(stat_specs=user_info.stat_info) is not None: - return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) - else: - return semantic_search_result + for stat_searcher in self.stat_searcher_list: + if stat_searcher.is_applicable_user(user_info): + filtered_learnware_list = [ + learnware for learnware in learnware_list if stat_searcher.is_applicable_learnware(learnware) + ] + return stat_searcher(filtered_learnware_list, user_info, max_search_num, search_method) + + return semantic_search_result diff --git a/learnware/market/heterogeneous/__init__.py b/learnware/market/heterogeneous/__init__.py index 4162f1d2..3c228a24 100644 --- a/learnware/market/heterogeneous/__init__.py +++ b/learnware/market/heterogeneous/__init__.py @@ -5,10 +5,10 @@ if not is_torch_available(verbose=False): HeteroMapTableOrganizer = None - HeteroSearcher = None - logger.error("HeteroMapTableOrganizer and HeteroSearcher are not available because 'torch' is not installed!") + HeteroStatSearcher = None + logger.error("HeteroMapTableOrganizer and HeteroStatSearcher are not available because 'torch' is not installed!") else: from .organizer import HeteroMapTableOrganizer - from .searcher import HeteroSearcher + from .searcher import HeteroStatSearcher -__all__ = ["HeteroMapTableOrganizer", "HeteroSearcher"] +__all__ = ["HeteroMapTableOrganizer", "HeteroStatSearcher"] diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index 3dcbc6d5..e04719a8 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -89,11 +89,11 @@ def add_learnware( - str indicating model_id - int indicating the final learnware check_status """ - learnware_id, learnwere_status = super(HeteroMapTableOrganizer, self).add_learnware( + learnware_id, learnware_status = super(HeteroMapTableOrganizer, self).add_learnware( zip_path, semantic_spec, check_status, learnware_id ) - if learnwere_status == BaseChecker.USABLE_LEARNWARE and len(self._get_hetero_learnware_ids(learnware_id)): + if learnware_status == BaseChecker.USABLE_LEARNWARE and len(self._get_hetero_learnware_ids(learnware_id)): self._update_learnware_hetero_spec(learnware_id) if self.auto_update: @@ -115,7 +115,7 @@ def add_learnware( self.count_down = 
self.auto_update_limit - return learnware_id, learnwere_status + return learnware_id, learnware_status def delete_learnware(self, id: str) -> bool: """Delete learnware from heterogeneous learnware market. diff --git a/learnware/market/heterogeneous/searcher.py b/learnware/market/heterogeneous/searcher.py index 5a10ac0c..40710b6e 100644 --- a/learnware/market/heterogeneous/searcher.py +++ b/learnware/market/heterogeneous/searcher.py @@ -1,19 +1,36 @@ -from typing import Optional +from typing import List from .utils import is_hetero from ..base import BaseUserInfo, SearchResults -from ..easy import EasySearcher -from ..utils import parse_specification_type +from ..easy import EasyStatSearcher +from ...learnware import Learnware from ...logger import get_module_logger logger = get_module_logger("hetero_searcher") -class HeteroSearcher(EasySearcher): +class HeteroStatSearcher(EasyStatSearcher): + SPEC_TYPES = ["HeteroMapTableSpecification"] + + def is_applicable_learnware(self, learnware: Learnware) -> bool: + if not super(HeteroStatSearcher, self).is_applicable_learnware(learnware): + return False + + spec = learnware.get_specification() + return is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec(), verbose=False) + + def is_applicable_user(self, user_info: BaseUserInfo) -> bool: + if not super(HeteroStatSearcher, self).is_applicable_user(user_info): + return False + + stat_specs = user_info.stat_info + semantic_spec = user_info.semantic_spec + return is_hetero(stat_specs=stat_specs, semantic_spec=semantic_spec, verbose=False) + def __call__( self, + learnware_list: List[Learnware], user_info: BaseUserInfo, - check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy", ) -> SearchResults: @@ -38,17 +55,7 @@ def __call__( the third is the score of Learnware (mixture) the fourth is the list of Learnware (mixture), the size is search_num """ - learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) - semantic_search_result = self.semantic_searcher(learnware_list, user_info) - - learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] - if len(learnware_list) == 0: - return SearchResults() - - if parse_specification_type(stat_specs=user_info.stat_info) is not None: - if is_hetero(stat_specs=user_info.stat_info, semantic_spec=user_info.semantic_spec): - user_hetero_spec = self.learnware_organizer.generate_hetero_map_spec(user_info) - user_info.update_stat_info(user_hetero_spec.type, user_hetero_spec) - return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) - else: - return semantic_search_result + user_hetero_spec = self.learnware_organizer.generate_hetero_map_spec(user_info) + user_info.update_stat_info(user_hetero_spec.type, user_hetero_spec) + + return super().__call__(learnware_list, user_info, max_search_num, search_method) diff --git a/learnware/market/heterogeneous/utils.py b/learnware/market/heterogeneous/utils.py index 860159e3..67b7fcbe 100644 --- a/learnware/market/heterogeneous/utils.py +++ b/learnware/market/heterogeneous/utils.py @@ -32,9 +32,9 @@ def is_hetero(stat_specs: dict, semantic_spec: dict, verbose=True) -> bool: semantic_input_description = semantic_spec["Input"] semantic_description_dim = int(semantic_input_description["Dimension"]) - semantic_decription_feature_num = len(semantic_input_description["Description"]) + semantic_description_feature_num = len(semantic_input_description["Description"]) - if 
semantic_decription_feature_num <= 0: + if semantic_description_feature_num <= 0: if verbose: logger.warning("At least one of Input.Description in semantic spec should be provides.") return False diff --git a/learnware/market/llm/__init__.py b/learnware/market/llm/__init__.py new file mode 100644 index 00000000..ecec1109 --- /dev/null +++ b/learnware/market/llm/__init__.py @@ -0,0 +1,14 @@ +from ...logger import get_module_logger +from ...utils import is_torch_available + +logger = get_module_logger("market_llm") + +if not is_torch_available(verbose=False): + LLMEasyOrganizer = None + LLMStatSearcher = None + logger.error("LLMStatSearcher and LLMEasyOrganizer are not available because 'torch' is not installed!") +else: + from .organizer import LLMEasyOrganizer + from .searcher import LLMStatSearcher + +__all__ = ["LLMEasyOrganizer", "LLMStatSearcher"] diff --git a/learnware/market/llm/organizer.py b/learnware/market/llm/organizer.py new file mode 100644 index 00000000..c44c3bb9 --- /dev/null +++ b/learnware/market/llm/organizer.py @@ -0,0 +1,104 @@ +import os +import tempfile +import traceback +import zipfile +from shutil import copyfile +from typing import List, Union + +from ..heterogeneous import HeteroMapTableOrganizer +from ...config import C +from ...logger import get_module_logger +from ...specification import LLMGeneralCapabilitySpecification +from ...utils import read_yaml_to_dict, save_dict_to_yaml + +logger = get_module_logger("llm_easy_organizer") + + +class LLMEasyOrganizer(HeteroMapTableOrganizer): + def _update_learnware_general_capability_spec(self, ids: Union[str, List[str]]): + """Update learnware by ids, attempting to generate LLMGeneralCapabilitySpecification for them. + + Parameters + ---------- + ids : Union[str, List[str]] + Give a id or a list of ids + str: id of target learnware + List[str]: A list of ids of target learnwares + """ + if isinstance(ids, str): + ids = [ids] + + for idx in ids: + try: + general_capability_spec = LLMGeneralCapabilitySpecification() + general_capability_spec.generate_stat_spec_from_system(learnware=self.learnware_list[idx]) + general_capability_spec_config = { + "module_path": "learnware.specification", + "class_name": general_capability_spec.type, + "file_name": "general_capability_spec.json", + "kwargs": {}, + } + + zip_path = self.learnware_zip_list[idx] + folder_dir = self.learnware_folder_list[idx] + self.learnware_list[idx].update_stat_spec(general_capability_spec.type, general_capability_spec) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + # update yaml file + with zipfile.ZipFile(zip_path, "r") as z_file: + z_file.extract(C.learnware_folder_config["yaml_file"], tempdir) + + learnware_yaml_path = os.path.join(tempdir, C.learnware_folder_config["yaml_file"]) + yaml_config = read_yaml_to_dict(learnware_yaml_path) + if "stat_specifications" in yaml_config: + yaml_config["stat_specifications"].append(general_capability_spec_config) + else: + yaml_config["stat_specifications"] = [general_capability_spec_config] + pass + save_dict_to_yaml(yaml_config, learnware_yaml_path) + + with zipfile.ZipFile(zip_path, "a") as z_file: + z_file.write(learnware_yaml_path, C.learnware_folder_config["yaml_file"]) + + # save general capability specification + stat_spec_path = os.path.join(tempdir, general_capability_spec_config["file_name"]) + general_capability_spec.save(stat_spec_path) + with zipfile.ZipFile(zip_path, "a") as z_file: + z_file.write(stat_spec_path, general_capability_spec_config["file_name"]) + + # update 
learnware folder + copyfile(learnware_yaml_path, os.path.join(folder_dir, C.learnware_folder_config["yaml_file"])) + copyfile(stat_spec_path, os.path.join(folder_dir, general_capability_spec_config["file_name"])) + + except Exception: + traceback.print_exc() + logger.warning(f"Learnware {idx} failed to generate LLMGeneralCapabilitySpecification!") + + def _get_llm_base_model_learnware_ids(self, ids: Union[str, List[str]]) -> List[str]: + """Get ids of learnwares that contain an LLM base model. + + Parameters + ---------- + ids : Union[str, List[str]] + Give an id or a list of ids + str: id of target learnware + List[str]: A list of ids of target learnwares + + Returns + ------- + List[str] + Learnware ids + """ + if isinstance(ids, str): + ids = [ids] + + ret = [] + for idx in ids: + semantic_spec = self.learnware_list[idx].get_specification().get_semantic_spec() + if ( + semantic_spec["Data"]["Values"] == ["Text"] + and semantic_spec["Task"]["Values"] == ["Text Generation"] + and semantic_spec["Model"]["Values"] == ["Base Model"] + ): + ret.append(idx) + return ret diff --git a/learnware/market/llm/searcher.py b/learnware/market/llm/searcher.py new file mode 100644 index 00000000..83015c6d --- /dev/null +++ b/learnware/market/llm/searcher.py @@ -0,0 +1,129 @@ +from typing import List, Tuple + +import numpy as np +import torch +from torch.nn.functional import softmax + +from learnware.learnware.base import Learnware +from learnware.specification.base import Specification + +from ..base import BaseUserInfo, SearchResults, SingleSearchItem +from ..easy import EasyStatSearcher +from ..utils import parse_specification_type +from ...logger import get_module_logger + +logger = get_module_logger("llm_searcher") + + +class LLMStatSearcher(EasyStatSearcher): + SPEC_TYPES = ["GenerativeModelSpecification"] + + def is_applicable_user(self, user_info: BaseUserInfo, verbose: bool = True) -> bool: + stat_specs = user_info.stat_info + semantic_spec = user_info.semantic_spec + try: + if "GenerativeModelSpecification" not in stat_specs: + if verbose: + logger.warning("GenerativeModelSpecification is not provided in stat_info.") + return False + + semantic_data_type = semantic_spec["Data"]["Values"] + if len(semantic_data_type) > 0 and semantic_data_type != ["Text"]: + logger.warning("User doesn't provide correct data type, it must be Text.") + return False + + semantic_task_type = semantic_spec["Task"]["Values"] + if len(semantic_task_type) > 0 and semantic_task_type != ["Text Generation"]: + logger.warning("User doesn't provide correct task type, it must be Text Generation.") + return False + + return True + except Exception: + if verbose: + logger.warning("Invalid LLM search information provided.") + return False + + def __call__( + self, + learnware_list: List[Learnware], + user_info: BaseUserInfo, + max_search_num: int = 5, + search_method: str = "greedy", + ) -> SearchResults: + self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info, spec_list=self.SPEC_TYPES) + + user_spec = user_info.stat_info[self.stat_spec_type] + + sorted_metric_list, single_learnware_list = self._search_by_taskvector_spec_single(learnware_list, user_spec) + if len(single_learnware_list) == 0: + return SearchResults() + + if self.stat_spec_type == "GenerativeModelSpecification": + sorted_score_list = self._convert_similarity_to_score(sorted_metric_list) + else: + sorted_score_list = self._convert_dist_to_score(sorted_metric_list) + + logger.info(f"After search by user 
spec, the number of matched learnwares is {len(single_learnware_list)}") + + if len(single_learnware_list) == 1 and sorted_score_list[0] < 0.6: + sorted_score_list[0] = 0.6 + + search_results = SearchResults() + search_results.update_single_results( + [ + SingleSearchItem(learnware=_learnware, score=_score) + for _score, _learnware in zip(sorted_score_list, single_learnware_list) + ] + ) + + return search_results + + def _search_by_taskvector_spec_single( + self, + learnware_list: List[Learnware], + user_spec: Specification, + stat_spec_type: str = "GenerativeModelSpecification", + ) -> Tuple[List[float], List[Learnware]]: + """Calculate the similarities between learnwares in the given learnware_list and user_spec + + Parameters + ---------- + learnware_list : List[Learnware] + The list of candidate learnwares to be ranked + user_spec : GenerativeModelSpecification + The user's task vector statistical specification + stat_spec_type : str + GenerativeModelSpecification by default. + + Returns + ------- + Tuple[List[float], List[Learnware]] + the first is the list of cosine similarities + the second is the list of Learnware + both lists are sorted by cosine similarity in descending order + """ + spec_list = [learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in learnware_list] + filtered_idx_list, similarity_list = [], [] + for idx, s in enumerate(spec_list): + user_spec.task_vector = user_spec.task_vector.to(s.task_vector.device) + similarity = float(s.similarity(user_spec)) + if np.isfinite(similarity): + similarity_list.append(similarity) + filtered_idx_list.append(idx) + else: + logger.warning( + f"The similarity between user_spec and learnware_spec (id: {learnware_list[idx].id}) is not finite, where similarity={similarity}" + ) + + sorted_idx_list = list(reversed(sorted(range(len(similarity_list)), key=lambda k: similarity_list[k]))) + sorted_similarity_list = [similarity_list[idx] for idx in sorted_idx_list] + sorted_learnware_list = [learnware_list[filtered_idx_list[idx]] for idx in sorted_idx_list] + + return sorted_similarity_list, sorted_learnware_list + + def _convert_similarity_to_score(self, sorted_similarity_list, temperature=0.1): + sorted_similarity = torch.asarray(sorted_similarity_list) + sorted_similarity = torch.stack([sorted_similarity, torch.zeros_like(sorted_similarity)]) + + scores = softmax(sorted_similarity / temperature, dim=0)[0].tolist() + return [score * 100 for score in scores] diff --git a/learnware/market/module.py b/learnware/market/module.py index cdc13e78..ba4d9acf 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -1,7 +1,15 @@ from .base import LearnwareMarket from .classes import CondaChecker -from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker -from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher +from .easy import ( + EasyFuzzSemanticSearcher, + EasyOrganizer, + EasySemanticChecker, + EasyStatChecker, + EasyStatSearcher, + SeqCombinedSearcher, +) +from .heterogeneous import HeteroMapTableOrganizer, HeteroStatSearcher +from .llm import LLMEasyOrganizer, LLMStatSearcher def get_market_component( @@ -13,19 +21,37 @@ def get_market_component( if name == "easy": easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) - easy_searcher = EasySearcher(organizer=easy_organizer) + + semantic_searcher_list = [EasyFuzzSemanticSearcher(easy_organizer)] + stat_searcher_list = [EasyStatSearcher(easy_organizer)] + easy_searcher = SeqCombinedSearcher( + organizer=easy_organizer, + 
semantic_searcher_list=semantic_searcher_list, + stat_searcher_list=stat_searcher_list, + ) + easy_checker_list = [ EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), ] + market_component = { "organizer": easy_organizer, "searcher": easy_searcher, "checker_list": easy_checker_list, } + elif name == "hetero": hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) - hetero_searcher = HeteroSearcher(organizer=hetero_organizer) + + semantic_searcher_list = [EasyFuzzSemanticSearcher(hetero_organizer)] + stat_searcher_list = [HeteroStatSearcher(hetero_organizer), EasyStatSearcher(hetero_organizer)] + hetero_searcher = SeqCombinedSearcher( + organizer=hetero_organizer, + semantic_searcher_list=semantic_searcher_list, + stat_searcher_list=stat_searcher_list, + ) + hetero_checker_list = [ EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), @@ -36,6 +62,34 @@ def get_market_component( "searcher": hetero_searcher, "checker_list": hetero_checker_list, } + + elif name == "llm": + llm_organizer = LLMEasyOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) + + semantic_searcher_list = [EasyFuzzSemanticSearcher(llm_organizer)] + stat_searcher_list = [ + LLMStatSearcher(llm_organizer), + HeteroStatSearcher(llm_organizer), + EasyStatSearcher(llm_organizer), + ] + + llm_searcher = SeqCombinedSearcher( + organizer=llm_organizer, + semantic_searcher_list=semantic_searcher_list, + stat_searcher_list=stat_searcher_list, + ) + + llm_checker_list = [ + EasySemanticChecker(), + EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), + ] + + market_component = { + "organizer": llm_organizer, + "searcher": llm_searcher, + "checker_list": llm_checker_list, + } + else: raise ValueError(f"name {name} is not supported for market") diff --git a/learnware/market/utils.py b/learnware/market/utils.py index 79411ba3..c13e2049 100644 --- a/learnware/market/utils.py +++ b/learnware/market/utils.py @@ -3,8 +3,10 @@ def parse_specification_type( spec_list=[ "HeteroMapTableSpecification", "RKMETableSpecification", + "GenerativeModelSpecification", "RKMETextSpecification", "RKMEImageSpecification", + "LLMGeneralCapabilitySpecification", ], ): for spec in spec_list: diff --git a/learnware/model/__init__.py b/learnware/model/__init__.py index d237fd17..a1dd4259 100644 --- a/learnware/model/__init__.py +++ b/learnware/model/__init__.py @@ -1,3 +1,9 @@ from .base import BaseModel +from ..utils import is_torch_available -__all__ = ["BaseModel"] +if not is_torch_available(verbose=False): + TorchModel = None +else: + from .torch_model import TorchModel + +__all__ = ["BaseModel", "TorchModel"] diff --git a/learnware/model/base.py b/learnware/model/base.py index 74e3860c..6e0fa19b 100644 --- a/learnware/model/base.py +++ b/learnware/model/base.py @@ -45,3 +45,11 @@ def finetune(self, X: np.ndarray, y: np.ndarray): labels for finetuning """ pass + + def get_model(self): + """Get the nn.Module object + + Returns: + nn.Module: The model object, such as a PreTrainedModel from the transformers library. 
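+ + Note: this base implementation is a stub and returns None; subclasses that wrap a torch or transformers model are expected to override it.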
+ """ + pass diff --git a/learnware/model/torch_model.py b/learnware/model/torch_model.py new file mode 100644 index 00000000..798c06dc --- /dev/null +++ b/learnware/model/torch_model.py @@ -0,0 +1,30 @@ +import numpy as np +from torch import nn + + +class TorchModel: + def __init__( + self, + model: nn.Module, + input_shape: tuple, + output_shape: tuple, + ): + self._model = model + self.input_shape = input_shape + self.output_shape = output_shape + + @property + def nn_model(self) -> nn.Module: + """ + fetch the inner model + """ + return self._model + + def predict(self, X: np.ndarray) -> np.ndarray: + pass + + def fit(self, X: np.ndarray, y: np.ndarray): + pass + + def finetune(self, X: np.ndarray, y: np.ndarray): + pass diff --git a/learnware/reuse/ensemble_pruning.py b/learnware/reuse/ensemble_pruning.py index 3ad0e950..7afc7f59 100644 --- a/learnware/reuse/ensemble_pruning.py +++ b/learnware/reuse/ensemble_pruning.py @@ -15,7 +15,7 @@ class EnsemblePruningReuser(BaseReuser): """ Baseline Multiple Learnware Reuser uing Marign Distribution guided multi-objective evolutionary Ensemble Pruning (MDEP) Method. - References: [1] Yu-Chang Wu, Yi-Xiao He, Chao Qian, and Zhi-Hua Zhou. Multi-objective Evolutionary Ensemble Pruning Guided by Margin Distribution. In: Proceedings of the 17th International Conference on Parallel Problem Solving from Nature (PPSN'22), Dortmund, Germany, 2022. + References: [1] Yu-Chang Wu, Yi-Xiao He, Chao Qian, and Zhi-Hua Zhou. Multi-objective evolutionary ensemble pruning guided by margin distribution. In: Proceedings of the 17th International Conference on Parallel Problem Solving from Nature (PPSN'22), 2022, pp.427-441. """ def __init__(self, learnware_list: List[Learnware] = None, mode: str = "classification"): diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 82246ff3..64bfe7d1 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,5 +1,6 @@ from .base import BaseStatSpecification, Specification from .regular import ( + GenerativeModelSpecification, RegularStatSpecification, RKMEImageSpecification, RKMEStatSpecification, @@ -7,7 +8,7 @@ RKMETextSpecification, rkme_solve_qp, ) -from .system import HeteroMapTableSpecification +from .system import HeteroMapTableSpecification, LLMGeneralCapabilitySpecification from ..utils import is_torch_available if not is_torch_available(verbose=False): @@ -15,9 +16,11 @@ generate_rkme_table_spec = None generate_rkme_image_spec = None generate_rkme_text_spec = None + generate_generative_model_spec = None generate_semantic_spec = None else: from .module import ( + generate_generative_model_spec, generate_rkme_image_spec, generate_rkme_table_spec, generate_rkme_text_spec, @@ -33,11 +36,14 @@ "RKMEStatSpecification", "RKMETableSpecification", "RKMETextSpecification", + "GenerativeModelSpecification", "HeteroMapTableSpecification", + "LLMGeneralCapabilitySpecification", "rkme_solve_qp", "generate_rkme_image_spec", "generate_rkme_table_spec", "generate_rkme_text_spec", + "generate_generative_model_spec", "generate_semantic_spec", "generate_stat_spec", ] diff --git a/learnware/specification/module.py b/learnware/specification/module.py index 9ad3d8a3..ffd0d2fb 100644 --- a/learnware/specification/module.py +++ b/learnware/specification/module.py @@ -3,8 +3,9 @@ import numpy as np import pandas as pd import torch +from datasets import Dataset -from .regular import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification +from .regular 
import GenerativeModelSpecification, RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification from .utils import convert_to_numpy from ..config import C @@ -175,6 +176,22 @@ def generate_rkme_text_spec( return rkme_text_spec +def generate_generative_model_spec( + dataset: Optional[Dataset] = None, dataset_text_field="text", X: List[str] = None, verbose: bool = True, **kwargs +) -> GenerativeModelSpecification: + # Check input type + if X is not None and (not isinstance(X, list) or not all(isinstance(item, str) for item in X)): + raise TypeError("Input data must be a list of strings.") + + # Generate generative model spec + task_vector_spec = GenerativeModelSpecification() + task_vector_spec.generate_stat_spec_from_data( + dataset=dataset, dataset_text_field=dataset_text_field, X=X, verbose=verbose, **kwargs + ) + + return task_vector_spec + + def generate_stat_spec( type: str, X: Union[np.ndarray, pd.DataFrame, torch.Tensor, List[str]], *args, **kwargs ) -> Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification]: @@ -211,6 +228,7 @@ def generate_semantic_spec( description: Optional[str] = None, data_type: Optional[str] = None, task_type: Optional[str] = None, + model_type: Optional[str] = None, library_type: Optional[str] = None, scenarios: Optional[Union[str, List[str]]] = None, license: Optional[Union[str, List[str]]] = None, @@ -220,6 +238,10 @@ def generate_semantic_spec( semantic_specification = dict() semantic_specification["Data"] = {"Type": "Class", "Values": [data_type] if data_type is not None else []} semantic_specification["Task"] = {"Type": "Class", "Values": [task_type] if task_type is not None else []} + semantic_specification["Model"] = { + "Type": "Optional", + "Values": [model_type] if model_type is not None else ["Others"], + } semantic_specification["Library"] = { "Type": "Class", "Values": [library_type] if library_type is not None else [], diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 51c79219..3544b566 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,7 +1,7 @@ from .base import RegularStatSpecification from .image import RKMEImageSpecification from .table import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp -from .text import RKMETextSpecification +from .text import GenerativeModelSpecification, RKMETextSpecification __all__ = [ "RegularStatSpecification", @@ -10,4 +10,5 @@ "RKMETableSpecification", "rkme_solve_qp", "RKMETextSpecification", + "GenerativeModelSpecification", ] diff --git a/learnware/specification/regular/base.py b/learnware/specification/regular/base.py index 1960f0d9..8159d12e 100644 --- a/learnware/specification/regular/base.py +++ b/learnware/specification/regular/base.py @@ -1,5 +1,7 @@ from __future__ import annotations +from torch.nn.functional import cosine_similarity + from ..base import BaseStatSpecification @@ -13,3 +15,21 @@ def generate_stat_spec_from_data(self, **kwargs): - kwargs also can include hyperparameters of specific method for specifaction generation """ raise NotImplementedError("generate_stat_spec_from_data is not implemented") + + +class TaskVectorSpecification(RegularStatSpecification): + @property + def task_vector(self): + raise NotImplementedError + + def similarity(self, other: TaskVectorSpecification) -> float: + """Compute cosine similarity between two task vectors.""" + v1, v2 = self.task_vector, other.task_vector + + return cosine_similarity(v1, 
v2, dim=0) + + def dist(self, other: BaseStatSpecification): + v1, v2 = self.task_vector, other.task_vector + + similarity = cosine_similarity(v1, v2, dim=0) # [-1, 1] + return (-similarity + 1) / 2 diff --git a/learnware/specification/regular/image/cnn_gp.py b/learnware/specification/regular/image/cnn_gp.py index 85d8cfd5..a429e1b2 100644 --- a/learnware/specification/regular/image/cnn_gp.py +++ b/learnware/specification/regular/image/cnn_gp.py @@ -11,7 +11,7 @@ Github Repository: https://github.com/cambridge-mlg/cnn-gp -References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen. Deep Convolutional Networks as shallow Gaussian Processes. In: International Conference on Learning Representations (ICLR'19), 2019. +References: [1] Adrià Garriga-Alonso, Laurence Aitchison, and Carl Edward Rasmussen. Deep convolutional networks as shallow gaussian processes. In: International Conference on Learning Representations (ICLR'19), 2019. """ diff --git a/learnware/specification/regular/text/__init__.py b/learnware/specification/regular/text/__init__.py index 18f2c2dd..47d1fc16 100644 --- a/learnware/specification/regular/text/__init__.py +++ b/learnware/specification/regular/text/__init__.py @@ -5,8 +5,12 @@ if not is_torch_available(verbose=False): RKMETextSpecification = None - logger.error("RKMETextSpecification is not available because 'torch' is not installed!") + GenerativeModelSpecification = None + logger.error( + "RKMETextSpecification and GenerativeModelSpecification are not available because 'torch' is not installed!" + ) else: + from .generative import GenerativeModelSpecification from .rkme import RKMETextSpecification -__all__ = ["RKMETextSpecification"] +__all__ = ["RKMETextSpecification", "GenerativeModelSpecification"] diff --git a/learnware/specification/regular/text/generative.py b/learnware/specification/regular/text/generative.py new file mode 100644 index 00000000..e5eb7b91 --- /dev/null +++ b/learnware/specification/regular/text/generative.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import random +import tempfile +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import trl +from datasets import Dataset +from peft import LoraConfig, PeftModel, get_peft_model +from torch import nn +from transformers import PreTrainedModel, Qwen2ForCausalLM, Qwen2Tokenizer, TrainingArguments +from trl import SFTConfig + +from ..base import TaskVectorSpecification +from ....logger import get_module_logger +from ....utils import allocate_cuda_idx, choose_device + +logger = get_module_logger("GenerativeModelSpecification", "INFO") + + +class GenerativeModelSpecification(TaskVectorSpecification): + """Task Vector Specification for Large Language Model""" + + def __init__( + self, + cuda_idx: int = None, + attn_implementation: str = "eager", + per_device_train_batch_size: int = 2, + gradient_accumulation_steps: int = 1, + max_seq_length: int = 2048, + **kwargs, + ): + """Initializing Task Vector Specification's parameters. + + Parameters + ---------- + cuda_idx : int, optional + A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. None indicates automatically choose device + + attn_implementation : str, optional + The type of attention implementation to use. Default is 'eager'. + + per_device_train_batch_size : int, optional + The training batch size for each device. Default is 2. 
+ + gradient_accumulation_steps : int, optional + The number of steps to accumulate gradients before an optimizer step. + Default is 1. + + max_seq_length : int, optional + The maximum sequence length for the model input. Default is 2048. + + **kwargs : dict + Additional keyword arguments. + """ + super(GenerativeModelSpecification, self).__init__(type=self.__class__.__name__) + + self._cuda_idx = allocate_cuda_idx() if cuda_idx is None else cuda_idx + self._device = choose_device(cuda_idx=self._cuda_idx) + + self._task_vector = None + + self.attn_implementation = attn_implementation + self.per_device_train_batch_size = per_device_train_batch_size + self.gradient_accumulation_steps = gradient_accumulation_steps + self.max_seq_length = max_seq_length + + self.__extra_args = { + "weight_decay_l1": 1.0, + "weight_decay_l2": 0.5, + "max_steps": 400, + "lr": 1e-5, + "max_grad_norm": 1.0, + "warmup_ratio": 0.03, + } + + @property + def task_vector(self): + if self._task_vector is None: + raise Exception("Call generate_stat_spec_from_data first!") + + return self._task_vector + + @task_vector.setter + def task_vector(self, value): + self._task_vector = value + + def generate_stat_spec_from_data( + self, + dataset: Optional[Dataset] = None, + dataset_text_field="text", + X: List[str] = None, + verbose: bool = True, + beimingwu=True, + **kwargs, + ): + """Generate the Task Vector Specification from a text dataset or a list of strings. + + Parameters + ---------- + dataset : Dataset, optional + A Hugging Face dataset containing the text field; if None, X must be provided. + dataset_text_field : str, optional + Name of the text field of the dataset. Default is "text". + X : List[str], optional + A list of text samples used to build a dataset when dataset is None. + verbose : bool, optional + Whether to enable verbose logging. Default is True. + beimingwu : bool, optional + Whether to load the base model and LoRA adapter from the Beimingwu system. Default is True. + """ + if dataset is None: + assert X is not None, "X and dataset cannot both be None." + dataset = Dataset.from_dict({dataset_text_field: X}) + + with tempfile.TemporaryDirectory() as temp_dir: + tokenizer, model = self._init_tokenizer_model(beimingwu) + trainer_config = self._trainer_config(temp_dir, dataset_text_field) + trainer = self._init_trainer(model, tokenizer, dataset, trainer_config) + + param_0 = [p.detach().clone() for n, p in trainer.model.named_parameters() if p.requires_grad] + trainer.train() + param_1 = [p.detach().clone() for n, p in trainer.model.named_parameters() if p.requires_grad] + + self._task_vector = torch.concatenate([(p1 - p0).reshape(-1) for p0, p1 in zip(param_0, param_1)]) + + def _init_tokenizer_model(self, beimingwu): + """ + Initialize the tokenizer and the foundation model (e.g., Qwen) used for task vector generation. + This method should not be overridden if the specification needs to be submitted to Beimingwu.
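+ When beimingwu is True, the base model and a LoRA adapter are downloaded through LearnwareClient and only the lora_B matrices are left trainable; otherwise a local Qwen/Qwen2.5-0.5B checkpoint is loaded, a fresh LoRA config is attached, and lora_A is frozen so that only lora_B is trained.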
+ """ + if beimingwu: + from ....client import LearnwareClient + + client = LearnwareClient() + base_model_path = client.get_pretrained_path("00002890") + else: + base_model_path = "Qwen/Qwen2.5-0.5B" + + set_seed(3407) + tokenizer = Qwen2Tokenizer.from_pretrained(base_model_path) + model = Qwen2ForCausalLM.from_pretrained( + base_model_path, + attn_implementation=self.attn_implementation, + torch_dtype=torch.bfloat16, + ).to(self._device) + + if beimingwu: + client = LearnwareClient() + adapter_path = client.get_pretrained_path("00002891") + model = PeftModel.from_pretrained(model, adapter_path) + + for n, p in model.named_parameters(): + if "lora_B" in n: + p.requires_grad = True + else: + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.1, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "k_proj", "v_proj"], + ) + model = get_peft_model(model, peft_config) + + for n, p in model.named_parameters(): + if "lora_A" in n: + p.requires_grad = False + + return tokenizer, model + + def _init_trainer(self, model, tokenizer, train_dataset, args): + # TODO: set_seed(3407) + trainer = CustomSFTTrainer( + model=model, + train_dataset=train_dataset, + tokenizer=tokenizer, + weight_decay_l1=self.__extra_args["weight_decay_l1"], + args=args, + ) + # Work around trl package bug with multi-GPU parallelism + trainer.args._n_gpu = 1 + + return trainer + + def _trainer_config(self, temp_dir, dataset_text_field): + training_params = SFTConfig( + output_dir=temp_dir, + max_steps=self.__extra_args["max_steps"], + per_device_train_batch_size=self.per_device_train_batch_size, + gradient_accumulation_steps=self.gradient_accumulation_steps, + learning_rate=self.__extra_args["lr"], + weight_decay=self.__extra_args["weight_decay_l2"], + optim="adamw_torch", + eval_strategy="no", + save_strategy="no", + # fp16=True, + # bf16=True, + max_grad_norm=self.__extra_args["max_grad_norm"], + warmup_ratio=self.__extra_args["warmup_ratio"], + group_by_length=True, + lr_scheduler_type="cosine", + ddp_timeout=180000000, + dataset_text_field=dataset_text_field, + max_seq_length=self.max_seq_length, + dataloader_num_workers=16, + seed=3407, + ) + + return training_params + + def save(self, filepath: str): + torch.save({"type": self.type, "task_vector": self.task_vector.detach().cpu()}, filepath) + + def load(self, filepath: str): + state = torch.load(filepath, weights_only=True) + if state["type"] != self.type: + logger.warning("{} may not be consistent with this class {}.".format(state["type"], self.type)) + self._task_vector = state["task_vector"].to(self._device) + + +class CustomSFTTrainer(trl.SFTTrainer): + def __init__(self, weight_decay_l1=None, **kwargs): + super().__init__(**kwargs) + model: Union[PreTrainedModel, nn.Module] = kwargs["model"] + args: TrainingArguments = kwargs["args"] + + if hasattr(args, "weight_decay_l1") and (weight_decay_l1 is not None): + print("Warning! weight_decay_l1 is overwrited by key args.") + if weight_decay_l1 is not None: + self.weight_decay_l1 = weight_decay_l1 + elif hasattr(args, "weight_decay_l1"): + self.weight_decay_l1 = args.weight_decay_l1 + else: + assert False, "weight_decay_l1 shounld be given." 
+ + self.parameters_l1_regularized = None + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, # noqa: F821 + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + self.parameters_l1_regularized = [ + (p, torch.nn.Parameter(p.clone().detach())) for n, p in self.model.named_parameters() if p.requires_grad + ] + + return super().train( + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + **kwargs, + ) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # implement custom logic here + default_loss, outputs = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + if self.weight_decay_l1 > 0: + l1_norm = sum((torch.linalg.norm(p - p0, 1) for p, p0 in self.parameters_l1_regularized)) + # We mask lora_A after init. + l1_norm = self.weight_decay_l1 / len(self.parameters_l1_regularized) * l1_norm + loss = default_loss + l1_norm + else: + loss = default_loss + + return (loss, outputs) if return_outputs else loss + + +def set_seed(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True diff --git a/learnware/specification/system/__init__.py b/learnware/specification/system/__init__.py index d89292a2..6472fa3c 100644 --- a/learnware/specification/system/__init__.py +++ b/learnware/specification/system/__init__.py @@ -6,8 +6,12 @@ if not is_torch_available(verbose=False): HeteroMapTableSpecification = None - logger.error("HeteroMapTableSpecification is not available because 'torch' is not installed!") + LLMGeneralCapabilitySpecification = None + logger.error( + "HeteroMapTableSpecification and LLMGeneralCapabilitySpecification are not available because 'torch' is not installed!" 
+ ) else: from .hetero_table import HeteroMapTableSpecification + from .llm_general_capability_spec.spec import LLMGeneralCapabilitySpecification -__all__ = ["SystemStatSpecification", "HeteroMapTableSpecification"] +__all__ = ["SystemStatSpecification", "HeteroMapTableSpecification", "LLMGeneralCapabilitySpecification"] diff --git a/learnware/specification/system/llm_general_capability_spec/__init__.py b/learnware/specification/system/llm_general_capability_spec/__init__.py new file mode 100644 index 00000000..da1cc9ea --- /dev/null +++ b/learnware/specification/system/llm_general_capability_spec/__init__.py @@ -0,0 +1,12 @@ +from ....logger import get_module_logger +from ....utils import is_torch_available + +logger = get_module_logger("system_general_capability_spec") + +if not is_torch_available(verbose=False): + LLMGeneralCapabilitySpecification = None + logger.error("LLMGeneralCapabilitySpecification are not available because 'torch' is not installed!") +else: + from .spec import LLMGeneralCapabilitySpecification + +__all__ = ["LLMGeneralCapabilitySpecification"] diff --git a/learnware/specification/system/llm_general_capability_spec/config.py b/learnware/specification/system/llm_general_capability_spec/config.py new file mode 100644 index 00000000..5fee42a0 --- /dev/null +++ b/learnware/specification/system/llm_general_capability_spec/config.py @@ -0,0 +1,166 @@ +from typing import List + +import numpy as np + +from ....tests.benchmarks import LLMBenchmarkConfig + +# Score normalization functions, copied from the interactive notebook in https://huggingface.co/docs/leaderboards/open_llm_leaderboard/normalization + + +def normalize_within_range(value, lower_bound=0, higher_bound=1): + return (np.clip(value - lower_bound, 0, None)) / (higher_bound - lower_bound) * 100 + + +def compute_bbh_score(data): + bbh_subtasks = { + "sports_understanding": 2, + "tracking_shuffled_objects_three_objects": 3, + "navigate": 2, + "snarks": 2, + "date_understanding": 6, + "reasoning_about_colored_objects": 18, + "object_counting": 19, + "logical_deduction_seven_objects": 7, + "geometric_shapes": 11, + "web_of_lies": 2, + "movie_recommendation": 6, + "logical_deduction_five_objects": 5, + "salient_translation_error_detection": 6, + "disambiguation_qa": 3, + "temporal_sequences": 4, + "hyperbaton": 2, + "logical_deduction_three_objects": 3, + "causal_judgement": 2, + "formal_fallacies": 2, + "tracking_shuffled_objects_seven_objects": 7, + "ruin_names": 6, + "penguins_in_a_table": 5, + "boolean_expressions": 2, + "tracking_shuffled_objects_five_objects": 5, + } + # Normalize BBH subtasks scores + bbh_scores = [] + for subtask, num_choices in bbh_subtasks.items(): + subtask_key = f"leaderboard_bbh_{subtask}" + if subtask_key in data["results"]: + bbh_raw_score = data["results"][subtask_key]["acc_norm,none"] + lower_bound = 1 / num_choices + normalized_score = normalize_within_range(bbh_raw_score, lower_bound, 1.0) + bbh_scores.append(normalized_score) + + # Average BBH score + bbh_score = sum(bbh_scores) / len(bbh_scores) + return round(bbh_score, 2) + + +def compute_gpqa_score(data): + gpqa_subtasks = ["leaderboard_gpqa_diamond", "leaderboard_gpqa_extended", "leaderboard_gpqa_main"] + # Normalize GPQA scores + gpqa_raw_scores = [] + for subtask in gpqa_subtasks: + gpqa_raw_scores.append(data["results"][subtask]["acc_norm,none"]) + gpqa_raw_score = sum(gpqa_raw_scores) / len(gpqa_raw_scores) + gpqa_score = normalize_within_range(gpqa_raw_score, 0.25, 1.0) + return round(gpqa_score, 2) + + +def 
compute_ifeval_score(data): + # Compute IFEval + ifeval_inst_score = data["results"]["leaderboard_ifeval"]["inst_level_strict_acc,none"] * 100 + ifeval_prompt_score = data["results"]["leaderboard_ifeval"]["prompt_level_strict_acc,none"] * 100 + + # Average IFEval scores + ifeval_score = (ifeval_inst_score + ifeval_prompt_score) / 2 + return round(ifeval_score, 2) + + +def compute_math_score(data): + math_subtasks = [ + "leaderboard_math_algebra_hard", + "leaderboard_math_counting_and_prob_hard", + "leaderboard_math_geometry_hard", + "leaderboard_math_intermediate_algebra_hard", + "leaderboard_math_num_theory_hard", + "leaderboard_math_prealgebra_hard", + "leaderboard_math_precalculus_hard", + ] + # Calculate the MATH score + math_raw_scores = [] + for subtask in math_subtasks: + math_raw_scores.append(data["results"][subtask]["exact_match,none"]) + math_raw_score = sum(math_raw_scores) / len(math_raw_scores) + math_score = normalize_within_range(math_raw_score, 0, 1.0) + return round(math_score, 2) + + +def compute_mmlu_pro_score(data): + # Normalize MMLU PRO scores + mmlu_pro_raw_score = data["results"]["leaderboard_mmlu_pro"]["acc,none"] + mmlu_pro_score = normalize_within_range(mmlu_pro_raw_score, 0.1, 1.0) + return round(mmlu_pro_score, 2) + + +def compute_musr_score(data): + musr_subtasks = {"murder_mysteries": 2, "object_placements": 5, "team_allocation": 3} + # Normalize MUSR scores + musr_scores = [] + + for subtask, num_choices in musr_subtasks.items(): + musr_raw_score = data["results"][f"leaderboard_musr_{subtask}"]["acc_norm,none"] + lower_bound = 1 / num_choices + normalized_score = normalize_within_range(musr_raw_score, lower_bound, 1.0) + musr_scores.append(normalized_score) + + musr_score = sum(musr_scores) / len(musr_scores) + return round(musr_score, 2) + + +test_benchmark_configs: List[LLMBenchmarkConfig] = [ + LLMBenchmarkConfig( + name="mmlu_anatomy", + dataset_path="hails/mmlu_no_train", + validation_split="validation", + test_split="test", + eval_metric="acc", + ), +] + +general_capability_benchmark_configs: List[LLMBenchmarkConfig] = [ + LLMBenchmarkConfig( + name="leaderboard_bbh", + dataset_path="SaylorTwift/bbh", + test_split="test", + score_function=compute_bbh_score, + ), + LLMBenchmarkConfig( + name="leaderboard_gpqa", + dataset_path="Idavidrein/gpqa", + test_split="train", + score_function=compute_gpqa_score, + ), + LLMBenchmarkConfig( + name="leaderboard_ifeval", + dataset_path="wis-k/instruction-following-eval", + test_split="train", + score_function=compute_ifeval_score, + ), + LLMBenchmarkConfig( + name="leaderboard_math_hard", + dataset_path="lighteval/MATH-Hard", + train_split="train", + test_split="test", + score_function=compute_math_score, + ), + LLMBenchmarkConfig( + name="leaderboard_mmlu_pro", + dataset_path="TIGER-Lab/MMLU-Pro", + validation_split="validation", + test_split="test", + score_function=compute_mmlu_pro_score, + ), + LLMBenchmarkConfig( + name="leaderboard_musr", + dataset_path="TAUR-Lab/MuSR", + score_function=compute_musr_score, + ), +] diff --git a/learnware/specification/system/llm_general_capability_spec/spec.py b/learnware/specification/system/llm_general_capability_spec/spec.py new file mode 100644 index 00000000..5dad78df --- /dev/null +++ b/learnware/specification/system/llm_general_capability_spec/spec.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import codecs +import json +import os +import traceback +from typing import Dict, List, Optional + +import lm_eval +from lm_eval.models.huggingface import HFLM + 
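+# lm_eval is the lm-evaluation-harness package; HFLM wraps a Hugging Face causal LM so the harness can run benchmark tasks against it.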
+from .config import general_capability_benchmark_configs +from ..base import SystemStatSpecification +from ....logger import get_module_logger +from ....tests.benchmarks import LLMBenchmarkConfig + +logger = get_module_logger("llm_general_capability_spec") + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class LLMGeneralCapabilitySpecification(SystemStatSpecification): + """Large Language Model General Capability Specification""" + + benchmark_configs: List[LLMBenchmarkConfig] = general_capability_benchmark_configs + + def __init__(self): + self.score_dict = None + super(LLMGeneralCapabilitySpecification, self).__init__(type=self.__class__.__name__) + + @staticmethod + def _get_scores(learnware, benchmark_configs: List[LLMBenchmarkConfig]) -> Dict: + """Use [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework to evaluate learnware according to benchmark_configs and compute score dict. + + Parameters + ---------- + learnware : Learnware + Learnware to generate General Capability Specification. + benchmark_configs : Optional[List[LLMBenchmarkConfig]] + List of LLMBenchmarkConfig. + + Returns + ------- + Dict[LLMBenchmarkConfig, float] + Scores of all benchmark_configs. + """ + learnware.instantiate_model() + base_model = learnware.get_model().get_model() + task_manager = lm_eval.tasks.TaskManager() + + score_dict = {} + for config in benchmark_configs: + try: + lm_obj = HFLM(pretrained=base_model, batch_size="auto") + results = lm_eval.simple_evaluate( + model=lm_obj, + tasks=[config.name], + task_manager=task_manager, + ) + + if config.score_function: + score = config.score_function(results) + else: + score = results["results"][config.name][f"{config.eval_metric},none"] * 100 + score = round(score, 2) + logger.info(f"Name: {config.name}, Score: {score}") + score_dict[config.name] = score + + except Exception as e: + traceback.print_exc() + message = f"Evaluation of {config.name} failed! Due to {repr(e)}." + logger.warning(message) + + return score_dict + + def generate_stat_spec_from_system( + self, + learnware, + benchmark_configs: Optional[List[LLMBenchmarkConfig]] = None, + update_existing: bool = False, + ): + """Construct Large Language Model General Capability Specification for Learnware. + + Parameters + ---------- + learnware : Learnware + Learnware to generate General Capability Specification. + benchmark_configs : Optional[List[LLMBenchmarkConfig]] + List of LLMBenchmarkConfig, set to self.benchmark_configs if None. + update_existing : bool + A flag indicating whether to update existing General Capability Specification's scores dict, by default false. + """ + if benchmark_configs: + for config in benchmark_configs: + if config.eval_metric is None and config.score_function is None: + raise Exception( + "Must specify an evaluation metric or a score computing function in a LLMBenchmarkConfig object to get the evaluation score." + ) + else: + logger.info("No passed benchmark_configs. 
Set benchmark_configs by default.") + benchmark_configs = self.benchmark_configs + if update_existing: + logger.info("Update existing LLMGeneralCapabilitySpecification.") + self.score_dict = self._get_scores(learnware, benchmark_configs) + else: + existing_config_names = [] + self.score_dict = {} + general_spec = learnware.get_specification().get_stat_spec_by_name("LLMGeneralCapabilitySpecification") + if general_spec: + existing_config_names = list(general_spec.score_dict.keys()) + self.score_dict = general_spec.score_dict.copy() + logger.info("LLMGeneralCapabilitySpecification exists in learnware. Try to update...") + for k, v in general_spec.score_dict.items(): + logger.info(f"Existing scores: Name: {k}, Score: {v}") + new_configs = [config for config in benchmark_configs if config.name not in existing_config_names] + if new_configs: + new_score_dict = self._get_scores(learnware, new_configs) + self.score_dict.update(new_score_dict) + else: + logger.info("All LLMBenchmarkConfig have been evaluated before. No update.") + + def __str__(self): + spec_to_save = self.get_states() + return json.dumps(spec_to_save, separators=(",", ":")) + + def save(self, filepath: str): + """Save the computed specification to a specified path in JSON format. + + Parameters + ---------- + filepath : str + The specified saving path + """ + save_path = filepath + spec_to_save = self.get_states() + with codecs.open(save_path, "w", encoding="utf-8") as fout: + json.dump(spec_to_save, fout, separators=(",", ":")) + + def load(self, filepath: str) -> bool: + """Load a specification file in JSON format from the specified path. + + Parameters + ---------- + filepath : str + The specified loading path. + + Returns + ------- + bool + True if the specification is loaded successfully. + """ + load_path = filepath + if os.path.exists(load_path): + with codecs.open(load_path, "r", encoding="utf-8") as fin: + obj_text = fin.read() + spec_load = json.loads(obj_text) + + for d in self.get_states(): + if d in spec_load.keys(): + if d == "type" and spec_load[d] != self.type: + raise TypeError( + f"The type of loaded Specification ({spec_load[d]}) is different from the expected type ({self.type})!" 
+ ) + setattr(self, d, spec_load[d]) diff --git a/learnware/tests/__init__.py b/learnware/tests/__init__.py index 898b2fec..7e073c99 100644 --- a/learnware/tests/__init__.py +++ b/learnware/tests/__init__.py @@ -1,3 +1,4 @@ +from .benchmarks.config import llm_general_capability_benchmark_configs from .utils import parametrize -__all__ = ["parametrize"] +__all__ = ["parametrize", "llm_general_capability_benchmark_configs"] diff --git a/learnware/tests/benchmarks/__init__.py b/learnware/tests/benchmarks/__init__.py index 436d5aee..609185ee 100644 --- a/learnware/tests/benchmarks/__init__.py +++ b/learnware/tests/benchmarks/__init__.py @@ -3,11 +3,12 @@ import tempfile import zipfile from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np +from datasets import Dataset, load_dataset -from .config import BenchmarkConfig, benchmark_configs +from .config import BenchmarkConfig, LLMBenchmarkConfig, benchmark_configs from ..data import GetData from ...config import C @@ -71,7 +72,79 @@ def get_train_data( return ret -class LearnwareBenchmark: +@dataclass +class LLMBenchmark: + name: str + # HF dataset options + dataset_path: Optional[str] = None + subset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + train_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + # evaluation options + eval_metric: Optional[str] = None + score_function: Optional[Callable] = None + # formatting / prompting options + preprocess_function: Optional[Callable] = None + response_template: Optional[str] = None + + def __post_init__(self) -> None: + self.prepare_dataset() + + def prepare_dataset(self) -> None: + self.dataset = load_dataset( + path=self.dataset_path if self.dataset_path else self.name, + name=self.subset_name, + **self.dataset_kwargs if self.dataset_kwargs is not None else {}, + ) + + def get_train_dataset(self) -> Dataset: + if self.train_split: + train_dataset = self.dataset[self.train_split] + if self.dataset_path == "meta-math/GSM8K_zh": + train_dataset = train_dataset.filter(lambda x: x["split"] == "train") + if self.preprocess_function: + train_dataset = train_dataset.map(lambda x: {"text": self.preprocess_function(x)}, batched=True) + return train_dataset + + def get_val_dataset(self) -> Dataset: + if self.validation_split: + val_dataset = self.dataset[self.validation_split] + if self.preprocess_function: + val_dataset = val_dataset.map(lambda x: {"text": self.preprocess_function(x)}, batched=True) + return val_dataset + + def get_test_dataset(self) -> Dataset: + if self.test_split: + test_dataset = self.dataset[self.test_split] + if self.preprocess_function: + test_dataset = test_dataset.map(lambda x: {"text": self.preprocess_function(x)}, batched=True) + return test_dataset + + def get_train_data(self) -> List[str]: + if not self.preprocess_function: + raise Exception("Must specify a preprocess function to get train data!") + train_dataset = self.get_train_dataset() + train_data = train_dataset["text"] + return train_data + + def get_val_data(self) -> List[str]: + if not self.preprocess_function: + raise Exception("Must specify a preprocess function to get validation data!") + val_dataset = self.get_val_dataset() + val_data = val_dataset["text"] + return val_data + + def get_test_data(self) -> List[str]: + if not self.preprocess_function: + raise Exception("Must specify a preprocess function to get test data!") + test_dataset = 
self.get_test_dataset() + test_data = test_dataset["text"] + return test_data + + +class LearnwareBenchmarkManager: def __init__(self): self.benchmark_configs = benchmark_configs @@ -148,37 +221,53 @@ def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> return X_paths, y_paths - def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark: + def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig, LLMBenchmarkConfig]) -> Benchmark: if isinstance(benchmark_config, str): benchmark_config = self.benchmark_configs[benchmark_config] - if not isinstance(benchmark_config, BenchmarkConfig): + if not isinstance(benchmark_config, (BenchmarkConfig, LLMBenchmarkConfig)): raise ValueError( "benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!" ) - # Load test data - test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") - - # Load train data - train_X_paths, train_y_paths = None, None - if benchmark_config.train_data_path is not None: - train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train") - - # Load extra info - extra_info_path = None - if benchmark_config.extra_info_path is not None: - extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info") - if not os.path.exists(extra_info_path): - self._download_data(benchmark_config.extra_info_path, extra_info_path) - - return Benchmark( - name=benchmark_config.name, - user_num=benchmark_config.user_num, - learnware_ids=benchmark_config.learnware_ids, - test_X_paths=test_X_paths, - test_y_paths=test_y_paths, - train_X_paths=train_X_paths, - train_y_paths=train_y_paths, - extra_info_path=extra_info_path, - ) + if isinstance(benchmark_config, LLMBenchmarkConfig): + return LLMBenchmark( + name=benchmark_config.name, + dataset_path=benchmark_config.dataset_path, + subset_name=benchmark_config.subset_name, + dataset_kwargs=benchmark_config.dataset_kwargs, + train_split=benchmark_config.train_split, + validation_split=benchmark_config.validation_split, + test_split=benchmark_config.test_split, + eval_metric=benchmark_config.eval_metric, + score_function=benchmark_config.score_function, + preprocess_function=benchmark_config.preprocess_function, + response_template=benchmark_config.response_template, + ) + + elif isinstance(benchmark_config, BenchmarkConfig): + # Load test data + test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") + + # Load train data + train_X_paths, train_y_paths = None, None + if benchmark_config.train_data_path is not None: + train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train") + + # Load extra info + extra_info_path = None + if benchmark_config.extra_info_path is not None: + extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info") + if not os.path.exists(extra_info_path): + self._download_data(benchmark_config.extra_info_path, extra_info_path) + + return Benchmark( + name=benchmark_config.name, + user_num=benchmark_config.user_num, + learnware_ids=benchmark_config.learnware_ids, + test_X_paths=test_X_paths, + test_y_paths=test_y_paths, + train_X_paths=train_X_paths, + train_y_paths=train_y_paths, + extra_info_path=extra_info_path, + ) diff --git a/learnware/tests/benchmarks/config.py b/learnware/tests/benchmarks/config.py index 3921900f..e8bf60cb 100644 --- a/learnware/tests/benchmarks/config.py +++ b/learnware/tests/benchmarks/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing 
import Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union @dataclass @@ -12,4 +12,24 @@ class BenchmarkConfig: extra_info_path: Optional[str] = None -benchmark_configs: Dict[str, BenchmarkConfig] = {} +@dataclass +class LLMBenchmarkConfig: + name: str + # HF dataset options + dataset_path: Optional[str] = None + subset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + train_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + # evaluation options + eval_metric: Optional[str] = None + score_function: Optional[Callable] = None + # formatting / prompting options + preprocess_function: Optional[Callable] = None + response_template: Optional[str] = None + + +benchmark_configs: Dict[str, Union[BenchmarkConfig, LLMBenchmarkConfig]] = {} + +llm_general_capability_benchmark_configs: Dict[str, LLMBenchmarkConfig] = {} diff --git a/learnware/tests/benchmarks/llm_process_funcs.py b/learnware/tests/benchmarks/llm_process_funcs.py new file mode 100644 index 00000000..c55b6d21 --- /dev/null +++ b/learnware/tests/benchmarks/llm_process_funcs.py @@ -0,0 +1,376 @@ +import re +from typing import List + + +def preprocess_alpaca(docs) -> List[str]: + alpaca_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}" + instructions = docs["instruction"] + inputs = docs["input"] + outputs = docs["output"] + texts = [] + for instruction, input, output in zip(instructions, inputs, outputs): + text = alpaca_prompt.format(instruction, input, output) + texts.append(text) + return texts + + +def preprocess_alpaca_no_label(docs) -> List[str]: + alpaca_no_label_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n" + instructions = docs["instruction"] + inputs = docs["input"] + texts = [] + for instruction, input in zip(instructions, inputs): + text = alpaca_no_label_prompt.format(instruction, input) + texts.append(text) + return texts + + +def preprocess_alpaca_no_input(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["instruction"] + outputs = docs["output"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_alpaca_no_input_no_label(docs) -> List[str]: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["instruction"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qr(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
\n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["query"] + outputs = docs["response"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qr_no_label(docs) -> List[str]: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["query"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qr_zh(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["query_zh"] + outputs = docs["response_zh"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qr_zh_no_label(docs) -> List[str]: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["query_zh"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qa(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["question"] + outputs = docs["answer"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qa_no_label(docs) -> List[str]: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["question"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_qa_zh(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}" + instructions = docs["question_zh"] + outputs = docs["answer_zh"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = alpaca_no_input_prompt.format(instruction, output) + texts.append(text) + return texts + + +def preprocess_qa_zh_no_label(docs) -> List[str]: + alpaca_no_input_no_label_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n" + instructions = docs["question_zh"] + texts = [] + for instruction in instructions: + text = alpaca_no_input_no_label_prompt.format(instruction) + texts.append(text) + return texts + + +def preprocess_finance(docs) -> List[str]: + alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
\n\n### Instruction:\n{}\n\n### Response:\n{}"
+    instructions = docs["query"]
+    outputs = docs["answer"]
+    texts = []
+    for instruction, output in zip(instructions, outputs):
+        instruction = instruction.removesuffix(" Answer:")  # drop the dataset's trailing " Answer:" cue; the prompt template supplies its own response marker
+        text = alpaca_no_input_prompt.format(instruction, output)
+        texts.append(text)
+    return texts
+
+
+def preprocess_math_train(docs) -> List[str]:
+    alpaca_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{}\n\n### Response:\n{}"
+    instructions = docs["question"]
+    outputs = docs["answer_detail"]
+    texts = []
+    for instruction, output in zip(instructions, outputs):
+        text = alpaca_no_input_prompt.format(instruction, output)
+        texts.append(text)
+    return texts
+
+
+def preprocess_medmcqa_no_label(docs) -> List[str]:
+    opas = docs["opa"]
+    opbs = docs["opb"]
+    opcs = docs["opc"]
+    opds = docs["opd"]
+    questions = docs["question"]
+    texts = []
+    for opa, opb, opc, opd, question in zip(opas, opbs, opcs, opds, questions):
+        option_choices = {
+            "A": opa,
+            "B": opb,
+            "C": opc,
+            "D": opd,
+        }
+        prompt = "Question: " + question + "\nChoices:\n"
+        for choice, option in option_choices.items():
+            prompt += f"{choice.upper()}. {option}\n"
+        prompt += "Answer:"
+        texts.append(prompt)
+    return texts
+
+
+def preprocess_medmcqa(docs) -> List[str]:
+    opas = docs["opa"]
+    opbs = docs["opb"]
+    opcs = docs["opc"]
+    opds = docs["opd"]
+    questions = docs["question"]
+    option_ids = docs["cop"]
+    texts = []
+    for opa, opb, opc, opd, question, option_id in zip(opas, opbs, opcs, opds, questions, option_ids):
+        option_choices = {
+            "A": opa,
+            "B": opb,
+            "C": opc,
+            "D": opd,
+        }
+        prompt = "Question: " + question + "\nChoices:\n"
+        for choice, option in option_choices.items():
+            prompt += f"{choice.upper()}. {option}\n"
+        prompt += f"Answer: {list(option_choices.keys())[option_id]}"
+        texts.append(prompt)
+    return texts
+
+
+def preprocess_medqa_no_label(docs) -> List[str]:
+    ending0s = docs["ending0"]
+    ending1s = docs["ending1"]
+    ending2s = docs["ending2"]
+    ending3s = docs["ending3"]
+    sent1s = docs["sent1"]
+    texts = []
+    for sent1, ending0, ending1, ending2, ending3 in zip(sent1s, ending0s, ending1s, ending2s, ending3s):
+        option_choices = {
+            "A": ending0,
+            "B": ending1,
+            "C": ending2,
+            "D": ending3,
+        }
+        answers = "".join((f"{k}. {v}\n") for k, v in option_choices.items())
+        texts.append(f"Question: {sent1}\n{answers}Answer:")
+    return texts
+
+
+def preprocess_medqa(docs) -> List[str]:
+    ending0s = docs["ending0"]
+    ending1s = docs["ending1"]
+    ending2s = docs["ending2"]
+    ending3s = docs["ending3"]
+    sent1s = docs["sent1"]
+    labels = docs["label"]
+    texts = []
+    for sent1, ending0, ending1, ending2, ending3, label in zip(sent1s, ending0s, ending1s, ending2s, ending3s, labels):
+        option_choices = {
+            "A": ending0,
+            "B": ending1,
+            "C": ending2,
+            "D": ending3,
+        }
+        answers = "".join((f"{k}. {v}\n") for k, v in option_choices.items())
+        texts.append(f"Question: {sent1}\n{answers}Answer: {list(option_choices.keys())[label]}")
+    return texts
+
+
+def preprocess_mmlu_no_label(docs) -> List[str]:
+    questions = docs["question"]
+    choices = docs["choices"]
+    texts = []
+    for question, options in zip(questions, choices):
+        texts.append(
+            "{}\nA. {}\nB. {}\nC. {}\nD. 
{}\nAnswer:".format( + question.strip(), options[0], options[1], options[2], options[3] + ) + ) + return texts + + +def preprocess_mmlu(docs) -> List[str]: + questions = docs["question"] + choices = docs["choices"] + answers = docs["answer"] + texts = [] + for question, options, answer in zip(questions, choices, answers): + texts.append( + "{}\nA. {}\nB. {}\nC. {}\nD. {}\nAnswer: {}".format( + question.strip(), options[0], options[1], options[2], options[3], ["A", "B", "C", "D"][answer] + ) + ) + return texts + + +def preprocess_pubmedqa_no_label(docs) -> List[str]: + contexts_list = docs["CONTEXTS"] + questions = docs["QUESTION"] + texts = [] + for contexts, question in zip(contexts_list, questions): + ctxs = "\n".join(contexts) + texts.append("Abstract: {}\nQuestion: {}\nAnswer:".format(ctxs, question)) + return texts + + +def preprocess_pubmedqa(docs) -> List[str]: + contexts_list = docs["CONTEXTS"] + questions = docs["QUESTION"] + answers = docs["final_decision"] + texts = [] + for contexts, question, answer in zip(contexts_list, questions, answers): + ctxs = "\n".join(contexts) + texts.append("Abstract: {}\nQuestion: {}\nAnswer: {}".format(ctxs, question, answer)) + return texts + + +def preprocess_agieval_no_label(docs) -> List[str]: + return docs["query"] + + +def preprocess_cmmlu_no_label(docs) -> List[str]: + questions = docs["Question"] + as_ = docs["A"] + bs = docs["B"] + cs = docs["C"] + ds = docs["D"] + texts = [] + for question, a, b, c, d in zip(questions, as_, bs, cs, ds): + texts.append("{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:".format(question.strip(), a, b, c, d)) + return texts + + +def preprocess_cmmlu(docs) -> List[str]: + questions = docs["Question"] + as_ = docs["A"] + bs = docs["B"] + cs = docs["C"] + ds = docs["D"] + answers = docs["Answer"] + texts = [] + for question, a, b, c, d, answer in zip(questions, as_, bs, cs, ds, answers): + texts.append("{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:{}".format(question.strip(), a, b, c, d, answer)) + return texts + + +def preprocess_mathqa_no_label(docs) -> List[str]: + problems = docs["Problem"] + texts = [f"Question: {problem}\nAnswer:" for problem in problems] + return texts + + +def preprocess_mathqa(docs) -> List[str]: + problems = docs["Problem"] + corrects = docs["correct"] + options = docs["options"] + texts = [] + for problem, correct, option in zip(problems, corrects, options): + choices = [c[4:].rstrip(" ,") for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", option)] + + # answer = ['a', 'b', 'c', 'd', 'e'].index(correct) + texts.append( + "Question: {}\na. {}\nb. {}\nc. {}\nd. {}\ne. 
{}\nAnswer: {}".format( + problem, choices[0], choices[1], choices[2], choices[3], choices[4], correct + ) + ) + return texts + + +def preprocess_mgsm_no_label(docs) -> List[str]: + questions = docs["question"] + texts = ["问题: " + question + "\n逐步解答:" for question in questions] + return texts + + +def preprocess_mgsm(docs) -> List[str]: + questions = docs["question"] + answers = docs["answer"] + texts = [question + "\n" + answer for question, answer in zip(questions, answers)] + return texts + + +def preprocess_gsm8k_no_label(docs) -> List[str]: + questions = docs["question"] + texts = [f"Question: {question}\nAnswer:" for question in questions] + return texts + + +def preprocess_gsm8k(docs) -> List[str]: + instructions = docs["question"] + outputs = docs["answer"] + texts = [] + for instruction, output in zip(instructions, outputs): + text = f"Question: {instruction}\nAnswer: {output}" + texts.append(text) + return texts + + +def preprocess_math_no_label(docs) -> List[str]: + problems = docs["problem"] + texts = ["Problem:" + "\n" + problem + "\n\n" + "Solution:" for problem in problems] + return texts + + +def preprocess_finance_no_label(docs) -> List[str]: + return docs["query"] diff --git a/setup.py b/setup.py index 7a485838..c0e66102 100644 --- a/setup.py +++ b/setup.py @@ -48,11 +48,15 @@ def get_version(rel_path: str) -> str: "docker>=6.1.3", "rapidfuzz>=3.4.0", "langdetect>=1.0.9", - "huggingface-hub<0.18", + "huggingface-hub", "transformers>=4.34.1", "portalocker>=2.0.0", "qpsolvers[clarabel]>=4.0.1", "geatpy>=2.7.0;python_version<'3.11'", + "trl>=0.11.4", + "datasets>=2.16.0", + "peft>=0.13.2", + "lm_eval>=0.4.7", ] DEV_REQUIRED = [ @@ -73,11 +77,11 @@ def get_version(rel_path: str) -> str: FULL_REQUIRED = [ # The default full requirements for learnware package - "torch==2.0.1", - "torchvision==0.15.2", + "torch>=2.1.0", + "torchvision>=0.16.0", "torch-optimizer>=0.3.0", "lightgbm>=3.3.0", - "sentence_transformers==2.2.2", + "sentence_transformers==3.2.1", "fast_pytorch_kmeans==0.2.0.1", ] diff --git a/tests/test_specification/test_general_spec.py b/tests/test_specification/test_general_spec.py new file mode 100644 index 00000000..cd6c472d --- /dev/null +++ b/tests/test_specification/test_general_spec.py @@ -0,0 +1,69 @@ +import json +import os +import tempfile +import unittest + +from learnware.specification.system.llm_general_capability_spec.config import test_benchmark_configs +from learnware.specification import LLMGeneralCapabilitySpecification +from learnware.client import LearnwareClient +from learnware.market import instantiate_learnware_market +from learnware.specification import generate_semantic_spec +from learnware.market import LearnwareMarket + +os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + +class TestGeneralCapabilitySpec(unittest.TestCase): + @staticmethod + def _test_general_spec(learnware, benchmark_configs): + spec = LLMGeneralCapabilitySpecification() + spec.generate_stat_spec_from_system(learnware=learnware, benchmark_configs=benchmark_configs) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + spec_path = os.path.join(tempdir, "general_spec.json") + spec.save(spec_path) + + with open(spec_path, "r") as f: + data = json.load(f) + assert data["type"] == "LLMGeneralCapabilitySpecification" + + spec2 = LLMGeneralCapabilitySpecification() + spec2.load(spec_path) + assert spec2.type == "LLMGeneralCapabilitySpecification" + + def test_general_spec(self): + client = LearnwareClient() + learnware = client.load_learnware(learnware_id="00002681") + 
self._test_general_spec(learnware, test_benchmark_configs) + + @staticmethod + def _prepare_learnware_market() -> LearnwareMarket: + """initialize learnware market""" + llm_market = instantiate_learnware_market(market_id="llm_test", name="llm", rebuild=True) + semantic_spec = generate_semantic_spec( + name="Qwen/Qwen2.5-0.5B", + description="Qwen/Qwen2.5-0.5B", + data_type="Text", + model_type="Base Model", + task_type="Text Generation", + library_type="PyTorch", + scenarios=["Others"], + license="Others", + input_description=None, + output_description=None, + ) + client = LearnwareClient() + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + zip_path = os.path.join(tempdir, "learnware.zip") + client.download_learnware(learnware_id="00002681", save_path=zip_path) + llm_market.add_learnware(zip_path, semantic_spec) + return llm_market + + def test_in_checker_organizer(self): + llm_market = self._prepare_learnware_market() + learnware_ids = llm_market.get_learnware_ids() + llm_market.learnware_organizer._update_learnware_general_capability_spec(learnware_ids) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_specification/test_text_generative.py b/tests/test_specification/test_text_generative.py new file mode 100644 index 00000000..1b26077c --- /dev/null +++ b/tests/test_specification/test_text_generative.py @@ -0,0 +1,110 @@ +import os +import tempfile +import unittest + +import torch + + +from learnware.learnware.base import Learnware +from learnware.market.llm import LLMStatSearcher +from learnware.specification.base import Specification +from learnware.specification.module import generate_generative_model_spec +from learnware.specification.regular.text import GenerativeModelSpecification + +from text_generative_utils import DATASET, prepare_data + + +class TestGenerativeModelSpecification(unittest.TestCase): + @staticmethod + def _test_with_X(X): + spec = GenerativeModelSpecification() + spec.generate_stat_spec_from_data(X=X, dataset_text_field="txt") + + task_vector = spec.task_vector + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + spec_path = os.path.join(tempdir, "spec.pth") + spec.save(spec_path) + + data = torch.load(spec_path, weights_only=True) + assert data["type"] == "GenerativeModelSpecification" + + spec2 = GenerativeModelSpecification() + spec2.load(spec_path) + + torch.testing.assert_close(task_vector.cpu(), spec2.task_vector.cpu()) + + assert spec2.type == "GenerativeModelSpecification" + + @staticmethod + def _test_with_dataset(dataset): + spec = GenerativeModelSpecification() + spec.generate_stat_spec_from_data(dataset=dataset) + + task_vector = spec.task_vector + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + spec_path = os.path.join(tempdir, "spec.pth") + spec.save(spec_path) + + data = torch.load(spec_path, weights_only=True) + assert data["type"] == "GenerativeModelSpecification" + + spec2 = GenerativeModelSpecification() + spec2.load(spec_path) + + torch.testing.assert_close(task_vector.cpu(), spec2.task_vector.cpu()) + assert spec2.type == "GenerativeModelSpecification" + + @staticmethod + def _test_with_generating_directly(X): + spec = generate_generative_model_spec(X=X, dataset_text_field="name") + + task_vector = spec.task_vector + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + spec_path = os.path.join(tempdir, "spec.pth") + spec.save(spec_path) + + data = torch.load(spec_path, weights_only=True) + assert data["type"] == "GenerativeModelSpecification" + 
+ spec2 = GenerativeModelSpecification() + spec2.load(spec_path) + + torch.testing.assert_close(task_vector.cpu(), spec2.task_vector.cpu()) + assert spec2.type == "GenerativeModelSpecification" + + def test_loading_from_bwm(self): + spec = GenerativeModelSpecification() + _, model1 = spec._init_tokenizer_model(True) + _, model2 = spec._init_tokenizer_model(False) + + params1, params2 = dict(model1.named_parameters()), dict(model2.named_parameters()) + for k in model1.state_dict(): + torch.testing.assert_close(params1[k].cpu(), params2[k].cpu()) + + def test_generating_spec(self): + train_dataset = prepare_data(DATASET["pubmedqa"]) + + self._test_with_X(train_dataset["text"]) + self._test_with_dataset(train_dataset) + self._test_with_dataset(train_dataset, beimingwu=False) + + def test_searching_spec(self): + specs, learnwares = [], [] + for i, dataset_name in enumerate(["pubmedqa", "medmcqa"]): + train_dataset = prepare_data(DATASET[dataset_name]) + + spec = GenerativeModelSpecification() + spec.generate_stat_spec_from_data(dataset=train_dataset) + + specs.append(spec) + learnwares.append(Learnware(str(i), None, Specification(stat_spec={spec.type: spec}), "")) + + searcher = LLMStatSearcher(None) + searcher._search_by_taskvector_spec_single(learnwares, specs[-1], specs[-1].type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_specification/text_generative_utils.py b/tests/test_specification/text_generative_utils.py new file mode 100644 index 00000000..83256513 --- /dev/null +++ b/tests/test_specification/text_generative_utils.py @@ -0,0 +1,60 @@ +from datasets import load_dataset + +DATASET = { + "medmcqa": "openlifescienceai/medmcqa", + "pubmedqa": "bigbio/pubmed_qa,pubmed_qa_labeled_fold0_source", +} + + +def preprocess_medmcqa(doc) -> str: + """ + Question: + Choices: + A. + B. + C. + D. + Answer: + """ + choices = [doc["opa"], doc["opb"], doc["opc"], doc["opd"]] + option_choices = { + "A": choices[0], + "B": choices[1], + "C": choices[2], + "D": choices[3], + } + + prompt = "Question: " + doc["question"] + "\nChoices:\n" + for choice, option in option_choices.items(): + prompt += f"{choice.upper()}. {option}\n" + prompt += "Answer:" + return prompt + + +def preprocess_pubmedqa(doc) -> str: + ctxs = "\n".join(doc["CONTEXTS"]) + return "Abstract: {}\nQuestion: {}\nAnswer:".format( + ctxs, + doc["QUESTION"], + ) + + +PROCESS_FUNC = { + # medical user + "openlifescienceai/medmcqa": preprocess_medmcqa, + "bigbio/pubmed_qa": preprocess_pubmedqa, +} + + +def prepare_data(dataset_name_str): + temp_list = dataset_name_str.split(",") + subset_name = None + if len(temp_list) != 1: + subset_name = temp_list[1] + dataset_name = temp_list[0] + if subset_name: + test_dataset = load_dataset(dataset_name, subset_name, split="test", trust_remote_code=True) + else: + test_dataset = load_dataset(dataset_name, split="test", trust_remote_code=True) + test_dataset = test_dataset.map(lambda x: {"text": PROCESS_FUNC[dataset_name](x)}) + return test_dataset
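
Note on how the new benchmark pieces fit together (illustration only, not part of the changeset): `LLMBenchmarkConfig` describes a Hugging Face dataset together with its prompt formatting, `LearnwareBenchmarkManager.get_benchmark` dispatches on the config type, and the resulting `LLMBenchmark` downloads the dataset in `__post_init__` and exposes the formatted splits. A minimal sketch follows; the GSM8K dataset id, subset name, and `response_template` value are illustrative assumptions rather than values taken from this diff.

```python
# Minimal usage sketch (assumptions noted inline); only the imported classes and
# functions come from the changes above.
from learnware.tests.benchmarks import LearnwareBenchmarkManager, LLMBenchmarkConfig
from learnware.tests.benchmarks.llm_process_funcs import preprocess_gsm8k

gsm8k_config = LLMBenchmarkConfig(
    name="gsm8k",
    dataset_path="openai/gsm8k",           # assumed HF dataset id, for illustration
    subset_name="main",                    # assumed subset name
    train_split="train",
    test_split="test",
    preprocess_function=preprocess_gsm8k,  # formats "Question: ...\nAnswer: ..." pairs
    response_template="Answer:",           # assumed marker for where the response begins
)

manager = LearnwareBenchmarkManager()
benchmark = manager.get_benchmark(gsm8k_config)  # LLMBenchmarkConfig branch returns an LLMBenchmark

train_texts = benchmark.get_train_data()         # list of formatted prompt+answer strings
print(len(train_texts), train_texts[0][:120])
```

For inference-time prompts without gold answers, the `*_no_label` variants in `llm_process_funcs.py` build the same prompts minus the answer. Keeping the LLM branch behind an `isinstance` check in `get_benchmark` leaves the existing tabular `Benchmark` path unchanged apart from re-indentation.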