diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9f43258 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,52 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: windows-latest + + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Miniconda on Windows + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + + - name: Create Conda environment + run: | + conda env create -f environment.yaml + + - name: Activate Conda environment and run tests + run: | + conda activate MED3pa + python -m unittest discover -s tests + shell: cmd + + - name: Install dependencies for documentation + run: | + conda activate MED3pa + conda install sphinx sphinx_rtd_theme + shell: cmd + + - name: Build documentation + run: | + conda activate MED3pa + cd docs + make html + shell: cmd \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..03d3cf7 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,63 @@ +name: Upload Python Package to PyPI and Deploy Documentation to Read the Docs + +on: + release: + types: [created] + +jobs: + pypi-publish: + name: Publish release to PyPI and Deploy Documentation + runs-on: ubuntu-latest + + permissions: + id-token: write + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.12" + + - name: Install Miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + + - name: Initialize Conda + run: | + conda init bash + shell: bash + + - name: Create Conda environment + run: | + conda env create -f environment.yaml + + - name: Activate Conda environment + run: | + source ~/.bashrc + conda activate MED3pa + shell: bash + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel setuptools-scm sphinx sphinx_rtd_theme + + - name: Clean dist directory + run: | + rm -rf dist/ + + - name: Build package + run: | + python setup.py sdist bdist_wheel # Could also be python -m build + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://upload.pypi.org/legacy/ + skip-existing: true + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e54b700 --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +# Ignore Python bytecode files +*.pyc +*.pyo +__pycache__/ + +# Ignore Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Ignore virtual environment directories +venv/ +env/ +.venv/ +.env/ + +# Ignore OS-specific files +.DS_Store +Thumbs.db + +# Ignore IDE-specific files +.vscode/ +.idea/ + +# Ignore log files +*.log + +# Ignore specific directories or files +node_modules/ +dist/ +build/ + +experiments/ + +*.MED3paResults +*.csv \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/MED3pa/__init__.py b/MED3pa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MED3pa/datasets/__init__.py b/MED3pa/datasets/__init__.py new file mode 100644 index 0000000..af60b1a --- /dev/null +++ b/MED3pa/datasets/__init__.py @@ -0,0 +1,6 @@ +from .loading_context import * +from .loading_strategies import * +from .manager import * +from .masked import * + + diff --git a/MED3pa/datasets/loading_context.py b/MED3pa/datasets/loading_context.py new file mode 100644 index 0000000..7e0ada5 --- /dev/null +++ b/MED3pa/datasets/loading_context.py @@ -0,0 +1,78 @@ +""" +This module provides a flexible framework for loading datasets from various file formats by utilizing the **strategy design pattern**. +It supports dynamic selection of data loading strategies based on the file extension, enabling easy extension and maintenance. +It includes the ``DataLoadingContext`` class, responsible for selecting and setting the right **loading strategy** based on the loaded file extension. +""" +import numpy as np +from typing import Tuple, List + +from .loading_strategies import DataLoadingStrategy, CSVDataLoadingStrategy + + +class DataLoadingContext: + """ + A context class for managing data loading strategies. It supports setting and getting the current + data loading strategy, as well as loading data as a NumPy array from a specified file. + """ + + strategies = { + 'csv': CSVDataLoadingStrategy, + } + + def __init__(self, file_path: str): + """ + Initializes the data loading context with a strategy based on the file extension. + + Args: + file_path (str): The path to the dataset file. + + Raises: + ValueError: If the file extension is not supported. + """ + file_extension = file_path.split('.')[-1] + strategy_class = self.strategies.get(file_extension, None) + if strategy_class is None: + raise ValueError(f"This file extension is not supported yet: '{file_extension}'") + self.selected_strategy = strategy_class() + + def set_strategy(self, strategy: DataLoadingStrategy) -> None: + """ + Sets a new data loading strategy. + + Args: + strategy (DataLoadingStrategy): The new data loading strategy to be used. + """ + self.selected_strategy = strategy + + def get_strategy(self) -> DataLoadingStrategy: + """ + Returns the currently selected data loading strategy. + + Returns: + DataLoadingStrategy: The currently selected data loading strategy. + """ + return self.selected_strategy + + def load_as_np(self, file_path: str, target_column_name: str) -> Tuple[List[str], np.ndarray, np.ndarray]: + """ + Loads data from the given file path and returns it as a NumPy array, along with column labels and the target data. + + Args: + file_path (str): The path to the dataset file. + target_column_name (str): The name of the target column, such as true labels or values in case of regression. + + Returns: + Tuple[List[str], np.ndarray, np.ndarray]: A tuple containing the column labels, observations as a NumPy array, + and the target as a NumPy array. + """ + return self.selected_strategy.execute(file_path, target_column_name) + + +def supported_file_formats() -> List[str]: + """ + Returns a list of supported file formats. + + Returns: + List[str]: A list of supported file formats. + """ + return list(DataLoadingContext.strategies) diff --git a/MED3pa/datasets/loading_strategies.py b/MED3pa/datasets/loading_strategies.py new file mode 100644 index 0000000..0e1e4fb --- /dev/null +++ b/MED3pa/datasets/loading_strategies.py @@ -0,0 +1,70 @@ +""" +This module provides strategies for loading data from files into usable Python formats, focusing on converting data into **NumPy** arrays. +It includes an abstract base class ``DataLoadingStrategy`` for defining common interfaces and concrete implementations of this class, such as ``CSVDataLoadingStrategy`` for handling CSV files. +This setup allows easy extension to support additional file types as needed. +""" + +import numpy as np +import pandas as pd +from typing import List, Tuple +from abc import ABC, abstractmethod + + +class DataLoadingStrategy(ABC): + """ + Abstract base class for data loading strategies. Defines a common interface for all data loading strategies. + """ + @staticmethod + @abstractmethod + def execute(path_to_file: str, target_column_name: str) -> Tuple[List[str], np.ndarray, np.ndarray]: + """ + Abstract method to execute the data loading strategy. + + Args: + path_to_file (str): The path to the file to be loaded. + target_column_name (str): The name of the target column in the dataset. + + Returns: + Tuple[List[str], np.ndarray, np.ndarray]: A tuple containing the column labels, observations as a NumPy array, + and the target as a NumPy array. + + """ + pass + + +class CSVDataLoadingStrategy(DataLoadingStrategy): + """ + Strategy class for loading CSV data. Implements the abstract execute method to handle CSV files. + + Methods: + execute(path_to_file: str, target_column_name: str) -> Tuple[List[str], np.ndarray, np.ndarray]: + Loads CSV data from the given path, separates observations and target, and converts them to NumPy arrays. + """ + + @staticmethod + def execute(path_to_file: str, target_column_name: str) -> Tuple[List[str], np.ndarray, np.ndarray]: + """ + Loads CSV data from the given path, separates observations and target, and converts them to NumPy arrays. + + Args: + path_to_file (str): The path to the CSV file to be loaded. + target_column_name (str): The name of the target column in the dataset. + + Returns: + Tuple[List[str], np.ndarray, np.ndarray]: Column labels, observations as a NumPy array, and target as a NumPy array. + """ + # Read the CSV file + df = pd.read_csv(path_to_file) + + # Separate observations and target + observations = df.drop(columns=[target_column_name]) + target = df[target_column_name] + column_labels = observations.columns.tolist() + + # Convert to NumPy arrays + obs_np = observations.to_numpy() + target_np = target.to_numpy() + + return column_labels, obs_np, target_np + + diff --git a/MED3pa/datasets/manager.py b/MED3pa/datasets/manager.py new file mode 100644 index 0000000..f87a34f --- /dev/null +++ b/MED3pa/datasets/manager.py @@ -0,0 +1,309 @@ +""" +The manager.py module manages the different datasets needed for machine learning workflows, particularly for ``+tron`` and ``Med3pa`` methods. +It includes the ``DatasetsManager`` class that contains the training, validation, reference, and testing datasets for a specific ML task. +""" + +import numpy as np +import pandas as pd +from typing import Union, List + +from .loading_context import DataLoadingContext +from .masked import MaskedDataset + + +class DatasetsManager: + """ + Manages various datasets for execution of med3pa methods. + + This manager is responsible for loading and holding different sets of data, including training, validation, + reference (or domain dataset), and testing datasets (or new encountered data). + """ + + def __init__(self): + """Initializes the DatasetsManager with empty datasets.""" + self.base_model_training_set = None + self.base_model_validation_set = None + self.reference_set = None + self.testing_set = None + self.column_labels = None + + def set_from_file(self, dataset_type: str, file: str, target_column_name: str) -> None: + """ + Loads and sets the specified dataset from a file. + + Args: + dataset_type (str): The type of dataset to set ('training', 'validation', 'reference', 'testing'). + file (str): The file path to the data. + target_column_name (str): The name of the target column in the dataset. + + Raises: + ValueError: If an invalid dataset_type is provided or if the shape of observations does not match column labels. + """ + ctx = DataLoadingContext(file) + column_labels, obs_np, true_labels_np = ctx.load_as_np(file, target_column_name) + + self.set_column_labels(column_labels) + + # Check if the number of columns in observations matches the length of column_labels + if obs_np.shape[1] != len(self.column_labels): + raise ValueError( + f"The shape of observations {obs_np.shape} does not match the length of column labels {len(column_labels)}") + + dataset = MaskedDataset(obs_np, true_labels_np, column_labels=self.column_labels) + dataset.set_file_path(file=file) + + mapping = { + 'training': 'base_model_training_set', + 'validation': 'base_model_validation_set', + 'reference': 'reference_set', + 'testing': 'testing_set', + } + + try: + setattr(self, mapping[dataset_type], dataset) + except KeyError: + raise ValueError(f"Invalid dataset_type provided: {dataset_type} \n" + f"Available datasets are: {list(mapping)}") + + def set_from_data(self, dataset_type: str, observations: np.ndarray, true_labels: np.ndarray, + column_labels: Union[List, pd.Index] = None) -> None: + """ + Sets the specified dataset using numpy arrays for observations and true labels. + + Args: + dataset_type (str): The type of dataset to set ('training', 'validation', 'reference', 'testing'). + observations (np.ndarray): The feature vectors of the dataset. + true_labels (np.ndarray): The true labels of the dataset. + column_labels (list, optional): The list of column labels for the dataset. Defaults to None. + + Raises: + ValueError: If an invalid dataset_type is provided or if column labels do not match existing column labels. + ValueError: If column_labels and target_column_name are not provided when column_labels are not set. + """ + if column_labels is not None: + if type(column_labels) is pd.Index: + column_labels = column_labels.tolist() + self.set_column_labels(column_labels) + elif self.column_labels is None: + raise ValueError("Column labels must be provided when setting a dataset for the first time.") + + dataset = MaskedDataset(observations, true_labels, column_labels=self.column_labels) + + mapping = { + 'training': 'base_model_training_set', + 'validation': 'base_model_validation_set', + 'reference': 'reference_set', + 'testing': 'testing_set', + } + + try: + setattr(self, mapping[dataset_type], dataset) + except KeyError: + raise ValueError(f"Invalid dataset_type provided: {dataset_type}, Available datasets are: {list(mapping)}") + + def set_column_labels(self, columns: list) -> None: + """ + Sets the column labels for the datasets, excluding the target column. + + Args: + columns (list): The list of columns excluding the target column. + + Raises: + ValueError: If the target column is not found in the list of columns. + """ + + if self.column_labels is None: + self.column_labels = columns + elif not np.array_equal(self.column_labels, columns): + raise ValueError("Provided column labels do not match the existing column labels.") + + for dataset in ( + self.base_model_training_set, + self.base_model_validation_set, + self.reference_set, + self.testing_set, + ): + if dataset is not None: + dataset.column_labels = columns + + def get_column_labels(self): + """ + Retrieves the column labels of the manager + + Returns: + List[str]: A list of the column labels extracted from the files. + + """ + return self.column_labels + + def get_info(self, show_details: bool = True) -> dict: + """ + Returns information about all the datasets managed by the DatasetsManager. + + Args: + show_details (bool): If True, includes detailed information about each dataset. If False, only indicates whether each dataset is set. + + Returns: + dict: A dictionary containing information about each dataset. + """ + if show_details: + datasets_info = { + 'training_set': self.base_model_training_set.get_info() if self.base_model_training_set else 'Not set', + 'validation_set': self.base_model_validation_set.get_info() if self.base_model_validation_set else 'Not set', + 'reference_set': self.reference_set.get_info() if self.reference_set else 'Not set', + 'testing_set': self.testing_set.get_info() if self.testing_set else 'Not set', + 'column_labels': self.column_labels if self.column_labels else 'Not set' + } + else: + datasets_info = { + 'training_set': 'Set' if self.base_model_training_set else 'Not set', + 'validation_set': 'Set' if self.base_model_validation_set else 'Not set', + 'reference_set': 'Set' if self.reference_set else 'Not set', + 'testing_set': 'Set' if self.testing_set else 'Not set', + 'column_labels': 'Set' if self.column_labels else 'Not set' + } + return datasets_info + + def summarize(self) -> None: + """ + Prints a summary of the manager. + """ + info = self.get_info() + print(f"training_set: {info['training_set']}") + print(f"validation_set: {info['validation_set']}") + print(f"reference_set: {info['reference_set']}") + print(f"testing_set: {info['testing_set']}") + print(f"column_labels: {info['column_labels']}") + + def reset_datasets(self) -> None: + """ + Resets all datasets in the manager. + """ + self.base_model_training_set = None + self.base_model_validation_set = None + self.reference_set = None + self.testing_set = None + self.column_labels = None + + def get_dataset_by_type(self, dataset_type: str, return_instance: bool = False) -> Union[tuple, MaskedDataset]: + """ + Helper method to get a dataset by type. + + Args: + dataset_type (str): The type of dataset to retrieve ('training', 'validation', 'reference', 'testing'). + return_instance (bool): If True, returns the MaskedDataset instance; otherwise, returns the observations and + true labels. Defaults to False. + + Returns: + MaskedDataset: The corresponding MaskedDataset instance. + + Raises: + ValueError: If an invalid dataset_type is provided. + """ + if dataset_type == 'training': + return self.__get_base_model_training_data(return_instance=return_instance) + elif dataset_type == 'validation': + return self.__get_base_model_validation_data(return_instance=return_instance) + elif dataset_type == 'reference': + return self.__get_reference_data(return_instance=return_instance) + elif dataset_type == 'testing': + return self.__get_testing_data(return_instance=return_instance) + else: + raise ValueError(f"Invalid dataset_type provided: {dataset_type}") + + def save_dataset_to_csv(self, dataset_type: str, file_path: str) -> None: + """ + Saves the specified dataset to a CSV file. + + Args: + dataset_type (str): The type of dataset to save ('training', 'validation', 'reference', 'testing'). + file_path (str): The file path to save the dataset to. + + Raises: + ValueError: If an invalid dataset_type is provided. + """ + dataset = self.get_dataset_by_type(dataset_type, True) + if dataset is None: + raise ValueError(f"Dataset '{dataset_type}' is not set.") + + dataset.save_to_csv(file_path) + + def __get_base_model_training_data(self, return_instance: bool = False) -> Union[tuple, MaskedDataset]: + """ + Retrieves the training dataset. + + Args: + return_instance (bool, optional): If True, returns the MaskedDataset instance; otherwise, returns the observations and true labels. Defaults to False. + + Returns: + Union[tuple, MaskedDataset]: The observations and true labels if return_instance is False, otherwise the MaskedDataset instance. + + Raises: + ValueError: If the base model training set is not initialized. + """ + if self.base_model_training_set is not None: + if return_instance: + return self.base_model_training_set + return self.base_model_training_set.get_observations(), self.base_model_training_set.get_true_labels() + else: + raise ValueError("Base model training set not initialized.") + + def __get_base_model_validation_data(self, return_instance: bool = False) -> Union[tuple, MaskedDataset]: + """ + Retrieves the validation dataset. + + Args: + return_instance (bool, optional): If True, returns the MaskedDataset instance; otherwise, returns the observations and true labels. Defaults to False. + + Returns: + Union[tuple, MaskedDataset]: The observations and true labels if return_instance is False, otherwise the MaskedDataset instance. + + Raises: + ValueError: If the base model validation set is not initialized. + """ + if self.base_model_validation_set is not None: + if return_instance: + return self.base_model_validation_set + return self.base_model_validation_set.get_observations(), self.base_model_validation_set.get_true_labels() + else: + raise ValueError("Base model validation set not initialized.") + + def __get_reference_data(self, return_instance: bool = False) -> Union[tuple, MaskedDataset]: + """ + Retrieves the reference dataset. + + Args: + return_instance (bool, optional): If True, returns the MaskedDataset instance; otherwise, returns the observations and true labels. Defaults to False. + + Returns: + Union[tuple, MaskedDataset]: The observations and true labels if return_instance is False, otherwise the MaskedDataset instance. + + Raises: + ValueError: If the reference set is not initialized. + """ + if self.reference_set is not None: + if return_instance: + return self.reference_set + return self.reference_set.get_observations(), self.reference_set.get_true_labels() + else: + raise ValueError("Reference set not initialized.") + + def __get_testing_data(self, return_instance: bool = False) -> Union[tuple, MaskedDataset]: + """ + Retrieves the testing dataset. + + Args: + return_instance (bool, optional): If True, returns the MaskedDataset instance; otherwise, returns the observations and true labels. Defaults to False. + + Returns: + Union[tuple, MaskedDataset]: The observations and true labels if return_instance is False, otherwise the MaskedDataset instance. + + Raises: + ValueError: If the testing set is not initialized. + """ + if self.testing_set is not None: + if return_instance: + return self.testing_set + return self.testing_set.get_observations(), self.testing_set.get_true_labels() + else: + raise ValueError("Testing set not initialized.") diff --git a/MED3pa/datasets/masked.py b/MED3pa/datasets/masked.py new file mode 100644 index 0000000..be8f4cc --- /dev/null +++ b/MED3pa/datasets/masked.py @@ -0,0 +1,384 @@ +""" +The masked.py module includes the ``MaskedDataset`` class that is capable of handling many dataset related operations, +such as cloning, sampling, refining, etc. +""" + +import numpy as np +import pandas as pd + +from torch.utils.data import Dataset + + +class MaskedDataset(Dataset): + """ + A dataset wrapper for PyTorch that supports masking and sampling of data points. + """ + + def __init__(self, observations: np.ndarray, true_labels: np.ndarray, column_labels: list = None): + """ + Initializes the MaskedDataset. + + Args: + observations (np.ndarray): The observations vectors of the dataset. + true_labels (np.ndarray): The true labels of the dataset. + column_labels (list, optional): The column labels for the observation vectors. Defaults to None. + """ + self.__observations = observations + self.__true_labels = true_labels + self.__indices = np.arange(len(self.__observations)) + self.__original_indices = self.__indices.copy() + self.__sample_counts = np.zeros(len(observations), dtype=int) + self.__pseudo_probabilities = None + self.__pseudo_labels = None + self.__confidence_scores = None + self.__column_labels = column_labels if column_labels is not None else [f'feature_{i}' for i in range(observations.shape[1])] + self.__file_path = None + + def __getitem__(self, index: int) -> tuple: + """ + Retrieves the data point and its label(s) at the given index. + + Args: + index (int): The index of the data point. + + Returns: + tuple: A tuple containing the observation vector, pseudo label, and true label. + """ + index = self.__indices[index] + x = self.__observations[index] + y = self.__true_labels[index] + y_hat = self.__pseudo_labels[index] if self.__pseudo_labels is not None else None + return x, y_hat, y + + def __len__(self) -> int: + """ + Gets the number of data points in the dataset. + + Returns: + int: The number of data points. + """ + return len(self.__indices) + + def refine(self, mask: np.ndarray) -> int: + """ + Refines the dataset by applying a mask to select specific data points. + + Args: + mask (np.ndarray): A boolean array indicating which data points to keep. + + Returns: + int: The number of data points remaining after applying the mask. + + Raises: + ValueError: If the length of the mask doesn't match the number of data points. + """ + if len(mask) != len(self.__observations): + raise ValueError("Mask length must match the number of data points.") + + self.__indices = self.__indices[mask] + self.__observations = self.__observations[mask] + self.__true_labels = self.__true_labels[mask] + if self.__pseudo_labels is not None: + self.__pseudo_labels = self.__pseudo_labels[mask] + if self.__pseudo_probabilities is not None: + self.__pseudo_probabilities = self.__pseudo_probabilities[mask] + if self.__confidence_scores is not None: + self.__confidence_scores = self.__confidence_scores[mask] + if self.__sample_counts is not None: + self.__sample_counts = self.__sample_counts[mask] + + return len(self.__observations) + + def reset_indices(self) -> None: + """Resets the indices of the dataset to the original indices.""" + self.__indices = self.__original_indices.copy() + + def sample_uniform(self, n_samples: int, seed: int) -> 'MaskedDataset': + """ + Samples N data points from the dataset, prioritizing the least sampled points. + + Args: + n_samples (int): The number of samples to return. + seed (int): The seed for random number generator. + + Returns: + MaskedDataset: A new instance of the dataset containing N random samples. + + Raises: + ValueError: If N is greater than the current number of data points in the dataset. + """ + if n_samples > len(self.__observations): + raise ValueError("N cannot be greater than the current number of data points in the dataset.") + + # Find the indices of the least sampled points + sorted_indices = np.argsort(self.__sample_counts) + least_sampled_indices = sorted_indices[:n_samples] + + # Set the seed for reproducibility and shuffle the least sampled indices + np.random.seed(seed) + np.random.shuffle(least_sampled_indices) + + # Select the first N after shuffling + sampled_indices = least_sampled_indices[:n_samples] + # Update the sample counts for the sampled indices + self.__sample_counts[sampled_indices] += 1 + + # Extract the sampled observations and labels + sampled_set = self.__sample_indices(sampled_indices) + + return sampled_set + + def sample_random(self, n_samples: int, seed: int) -> 'MaskedDataset': + """ + Samples N data points randomly from the dataset using the given seed. + + Args: + n_samples (int): The number of samples to return. + seed (int): The seed for random number generator. + + Returns: + MaskedDataset: A new instance of the dataset containing N random samples. + + Raises: + ValueError: If N is greater than the current number of data points in the dataset. + """ + if n_samples > len(self.__observations): + raise ValueError("N cannot be greater than the current number of data points in the dataset.") + + # Set the seed for reproducibility and generate random indices + rng = np.random.RandomState(seed) + random_indices = rng.permutation(len(self.__observations))[:n_samples] + + # Extract the sampled observations and labels + sampled_set = self.__sample_indices(random_indices) + + return sampled_set + + def __sample_indices(self, indices: np.ndarray) -> 'MaskedDataset': + """ + Samples data points from the dataset using the given indices. + + Args: + indices (np.ndarray): The indices of samples to return. + + Returns: + MaskedDataset: A new instance of the dataset containing samples corresponding to the given indices. + """ + # Extract the sampled observations and labels + sampled_obs = self.__observations[indices, :] + sampled_true_labels = self.__true_labels[indices] + sampled_pseudo_labels = self.__pseudo_labels[indices] if self.__pseudo_labels is not None else None + sampled_confidence_scores = self.__confidence_scores[ + indices] if self.__confidence_scores is not None else None + sampled_pseudo_probs = self.__pseudo_probabilities[ + indices] if self.__pseudo_probabilities is not None else None + + # Return a new MaskedDataset instance containing the sampled data + sampled_set = MaskedDataset(observations=sampled_obs, true_labels=sampled_true_labels, + column_labels=self.__column_labels) + sampled_set.set_pseudo_probs_labels(sampled_pseudo_probs) if sampled_pseudo_probs is not None else None + sampled_set.set_pseudo_labels(sampled_pseudo_labels) if sampled_pseudo_labels is not None else None + sampled_set.set_confidence_scores(sampled_confidence_scores) if sampled_confidence_scores is not None else None + + return sampled_set + + def get_observations(self) -> np.ndarray: + """ + Gets the observations vectors of the dataset. + + Returns: + np.ndarray: The observations vectors of the dataset. + """ + return self.__observations + + def get_pseudo_labels(self) -> np.ndarray: + """ + Gets the pseudo labels of the dataset. + + Returns: + np.ndarray: The pseudo labels of the dataset. + """ + return self.__pseudo_labels + + def get_true_labels(self) -> np.ndarray: + """ + Gets the true labels of the dataset. + + Returns: + np.ndarray: The true labels of the dataset. + """ + return self.__true_labels + + def get_pseudo_probabilities(self) -> np.ndarray: + """ + Gets the pseudo probabilities of the dataset. + + Returns: + np.ndarray: The pseudo probabilities of the dataset. + """ + return self.__pseudo_probabilities + + def get_confidence_scores(self) -> np.ndarray: + """ + Gets the confidence scores of the dataset. + + Returns: + np.ndarray: The confidence scores of the dataset. + """ + return self.__confidence_scores + + def get_sample_counts(self) -> np.ndarray: + """ + Gets the how many times each element of the dataset was sampled. + + Returns: + np.ndarray: The sample counts of the dataset. + """ + return self.__sample_counts + + def get_file_path(self) -> str : + """ + Gets the file path of the dataset if it has been set from a file. + + Returns: + str: The file path of the dataset. + """ + return self.__file_path + + def set_pseudo_probs_labels(self, pseudo_probabilities: np.ndarray, threshold: float = None) -> None: + """ + Sets the pseudo probabilities and corresponding pseudo labels for the dataset. The labels are derived by + applying a threshold to the probabilities. + + Args: + pseudo_probabilities (np.ndarray): The pseudo probabilities array to be set. + threshold (float): The threshold to convert probabilities to binary labels. + + Raises: + ValueError: If the shape of pseudo_probabilities does not match the number of samples in the observations array. + """ + if pseudo_probabilities.shape[0] != self.__observations.shape[0]: + raise ValueError("The shape of pseudo_probabilities must match the number of samples in the observations array.") + + self.__pseudo_probabilities = pseudo_probabilities + if threshold: + self.__pseudo_labels = pseudo_probabilities >= threshold + + def set_confidence_scores(self, confidence_scores: np.ndarray) -> None: + """ + Sets the confidence scores for the dataset. + + Args: + confidence_scores (np.ndarray): The confidence scores array to be set. + + Raises: + ValueError: If the shape of confidence_scores does not match the number of samples in the observations array. + """ + if confidence_scores.shape[0] != self.__observations.shape[0]: + raise ValueError("The shape of confidence_scores must match the number of samples in the observations array.") + + self.__confidence_scores = confidence_scores + + def set_pseudo_labels(self, pseudo_labels: np.ndarray) -> None: + """ + Adds pseudo labels to the dataset. + + Args: + pseudo_labels (np.ndarray): The pseudo labels to add. + + Raises: + ValueError: If the length of pseudo_labels does not match the number of samples. + """ + if len(pseudo_labels) != len(self.__observations): + raise ValueError("The length of pseudo_labels must match the number of samples in the dataset.") + self.__pseudo_labels = pseudo_labels + + def set_file_path(self, file: str) -> None: + """ + Sets the file path of the dataset if it has been set from a file. + + Args: + file (str): The file path of the dataset. + + """ + self.__file_path = file + + def clone(self) -> 'MaskedDataset': + """ + Creates a clone of the current MaskedDataset instance. + + Returns: + MaskedDataset: A new instance of MaskedDataset containing the same data and configurations as the current instance. + """ + cloned_set = MaskedDataset(observations=self.__observations.copy(), true_labels=self.__true_labels.copy(), column_labels=self.__column_labels) + cloned_set.__pseudo_labels = self.__pseudo_labels.copy() if self.__pseudo_labels is not None else None + cloned_set.__pseudo_probabilities = self.__pseudo_probabilities.copy() if self.__pseudo_probabilities is not None else None + cloned_set.__confidence_scores = self.__confidence_scores.copy() if self.__confidence_scores is not None else None + + return cloned_set + + def get_info(self) -> dict: + """ + Returns information about the MaskedDataset. + + Returns: + dict: A dictionary containing dataset information. + """ + info = { + 'file_path': self.__file_path, + 'num_samples': len(self.__observations), + 'num_observations': self.__observations.shape[1] if self.__observations.ndim > 1 else 1, + 'has_pseudo_labels': self.__pseudo_labels is not None, + 'has_pseudo_probabilities': self.__pseudo_probabilities is not None, + 'has_confidence_scores': self.__confidence_scores is not None, + } + return info + + def summarize(self) -> None: + """ + Prints a summary of the dataset. + """ + info = self.get_info() + print(f"Number of samples: {info['num_samples']}") + print(f"Number of observations: {info['num_observations']}") + print(f"Has pseudo labels: {info['has_pseudo_labels']}") + print(f"Has pseudo probabilities: {info['has_pseudo_probabilities']}") + print(f"Has confidence scores: {info['has_confidence_scores']}") + + def to_dataframe(self) -> pd.DataFrame: + """ + Converts the dataset to a pandas DataFrame. + + Returns: + pd.DataFrame: The dataset as a pandas DataFrame. + """ + # Convert observations to DataFrame + data = self.__observations.copy() + df = pd.DataFrame(data, columns=self.__column_labels) + + # Add true labels + df['true_labels'] = self.__true_labels + + # Add pseudo labels if available + if self.__pseudo_labels is not None: + df['pseudo_labels'] = self.__pseudo_labels + + # Add pseudo probabilities if available + if self.__pseudo_probabilities is not None: + df[f'pseudo_probabilities'] = self.__pseudo_probabilities + + # Add confidence scores if available + if self.__confidence_scores is not None: + df['confidence_scores'] = self.__confidence_scores + + return df + + def save_to_csv(self, file_path: str) -> None: + """ + Saves the dataset to a CSV file. + + Args: + file_path (str): The file path to save the dataset to. + """ + df = self.to_dataframe() + df.to_csv(file_path, index=False) diff --git a/MED3pa/med3pa/__init__.py b/MED3pa/med3pa/__init__.py new file mode 100644 index 0000000..ee74b97 --- /dev/null +++ b/MED3pa/med3pa/__init__.py @@ -0,0 +1,8 @@ +""" +The med3pa sub-package contains all the necessary classes and methods to execute the med3pa method +""" +from .experiment import * +from .mdr import * +from .models import * +from .profiles import * +from .uncertainty import * diff --git a/MED3pa/med3pa/experiment.py b/MED3pa/med3pa/experiment.py new file mode 100644 index 0000000..85f3de2 --- /dev/null +++ b/MED3pa/med3pa/experiment.py @@ -0,0 +1,410 @@ +""" +Orchestrates the execution of the med3pa method and integrates the functionality of other modules to run comprehensive experiments. +It includes ``Med3paExperiment`` to manage experiments. +""" + +from checkpointer import checkpoint +from sklearn.model_selection import train_test_split +from typing import Tuple + +from MED3pa.datasets import DatasetsManager +from MED3pa.med3pa.mdr import MDRCalculator +from MED3pa.med3pa.models import APCModel, IPCModel, MPCModel +from MED3pa.med3pa.profiles import ProfilesManager +from MED3pa.med3pa.results import Med3paResults, Med3paRecord +from MED3pa.med3pa.uncertainty import * +from MED3pa.models.base import BaseModelManager +from MED3pa.models.classification_metrics import * +from MED3pa.models.concrete_regressors import * + + +class Med3paExperiment: + """ + Class to run the MED3PA method experiment. + """ + + @staticmethod + @checkpoint(root_path="checkpoints", verbosity=False) + def run(datasets_manager: DatasetsManager, + base_model_manager: BaseModelManager = None, + uncertainty_metric: str = 'absolute_error', + ipc_type: str = 'RandomForestRegressor', + ipc_params: Dict = None, + ipc_grid_params: Dict = None, + ipc_cv: int = 4, + pretrained_ipc: str = None, + apc_params: Dict = None, + apc_grid_params: Dict = None, + apc_cv: int = 4, + pretrained_apc: str = None, + samples_ratio_min: int = 0, + samples_ratio_max: int = 50, + samples_ratio_step: int = 5, + med3pa_metrics: List[str] = None, + evaluate_models: bool = False, + use_ref_models: bool = False, + mode: str = 'mpc', + models_metrics: List[str] = None) -> Med3paResults: + + """ + Runs the MED3PA experiment on both reference and testing sets. + + Args: + datasets_manager (DatasetsManager): the datasets manager containing the dataset to use in the experiment. + base_model_manager (BaseModelManager, optional): Instance of BaseModelManager to get the base model, + by default None. + uncertainty_metric (str, optional): the uncertainty metric ysed to calculate uncertainty, + by default absolute_error. + ipc_type (str, optional): The regressor model to use for IPC, by default RandomForestRegressor. + ipc_params (dict, optional): Parameters for initializing the IPC regressor model, by default None. + ipc_grid_params (dict, optional): Grid search parameters for optimizing the IPC model, by default None. + ipc_cv (int, optional): Number of cross-validation folds for optimizing the IPC model, by default None. + pretrained_ipc (str, optional): path to a pretrained ipc, by default None. + apc_params (dict, optional): Parameters for initializing the APC regressor model, by default None. + apc_grid_params (dict, optional): Grid search parameters for optimizing the APC model, by default None. + apc_cv (int, optional): Number of cross-validation folds for optimizing the APC model, by default None. + pretrained_apc (str, optional): path to a pretrained apc, by default None. + use_ref_models (bool, optional): whether or not to use the trained IPC and APC models from the reference set + on the test set. + samples_ratio_min (int, optional): Minimum sample ratio, by default 0. + samples_ratio_max (int, optional): Maximum sample ratio, by default 50. + samples_ratio_step (int, optional): Step size for sample ratio, by default 5. + med3pa_metrics (list of str, optional): List of metrics to calculate, by default, multiple metrics included. + evaluate_models (bool, optional): Whether to evaluate the models, by default False. + mode (str): The modality of dataset, either 'ipc', 'apc', or 'mpc'. + models_metrics (list of str, optional): List of metrics for model evaluation, + by default ['MSE', 'RMSE', 'MAE']. + + Returns: + Med3paResults: the results of the MED3PA experiment on the reference set and testing set. + """ + if med3pa_metrics is None: + med3pa_metrics = ['Accuracy', 'BalancedAccuracy', 'Precision', 'Recall', 'F1Score', + 'Specificity', 'Sensitivity', 'Auc', 'LogLoss', 'Auprc', 'NPV', 'PPV', 'MCC'] + + if models_metrics is None: + models_metrics = ['MSE', 'RMSE', 'MAE'] + + results_ref = None + if datasets_manager.reference_set is not None: + print("Running MED3pa Experiment on the reference set:") + results_ref, ipc_config, apc_config = Med3paExperiment._run_by_set(datasets_manager=datasets_manager, + set='reference', + base_model_manager=base_model_manager, + uncertainty_metric=uncertainty_metric, + ipc_type=ipc_type, + ipc_params=ipc_params, + ipc_grid_params=ipc_grid_params, + ipc_cv=ipc_cv, + pretrained_ipc=pretrained_ipc, + apc_params=apc_params, + apc_grid_params=apc_grid_params, + apc_cv=apc_cv, + pretrained_apc=pretrained_apc, + samples_ratio_min=samples_ratio_min, + samples_ratio_max=samples_ratio_max, + samples_ratio_step=samples_ratio_step, + med3pa_metrics=med3pa_metrics, + evaluate_models=evaluate_models, + models_metrics=models_metrics, + mode=mode) + print("Running MED3pa Experiment on the test set:") + if use_ref_models: + if results_ref is None: + raise ValueError("use_ref_models cannot be true if reference set is None. ") + results_testing, ipc_config, apc_config = Med3paExperiment._run_by_set(datasets_manager=datasets_manager, + set='testing', + base_model_manager=base_model_manager, + uncertainty_metric=uncertainty_metric, + ipc_type=ipc_type, + ipc_params=ipc_params, + ipc_grid_params=ipc_grid_params, + ipc_cv=ipc_cv, + pretrained_ipc=pretrained_ipc, + ipc_instance=ipc_config, + apc_params=apc_params, + apc_grid_params=apc_grid_params, + apc_cv=apc_cv, + pretrained_apc=pretrained_apc, + apc_instance=apc_config, + samples_ratio_min=samples_ratio_min, + samples_ratio_max=samples_ratio_max, + samples_ratio_step=samples_ratio_step, + med3pa_metrics=med3pa_metrics, + evaluate_models=evaluate_models, + models_metrics=models_metrics, + mode=mode) + else: + results_testing, ipc_config, apc_config = Med3paExperiment._run_by_set(datasets_manager=datasets_manager, + set='testing', + base_model_manager=base_model_manager, + uncertainty_metric=uncertainty_metric, + ipc_type=ipc_type, + ipc_params=ipc_params, + ipc_grid_params=ipc_grid_params, + ipc_cv=ipc_cv, + pretrained_ipc=pretrained_ipc, + ipc_instance=None, + apc_params=apc_params, + apc_grid_params=apc_grid_params, + apc_cv=apc_cv, + pretrained_apc=pretrained_apc, + apc_instance=None, + samples_ratio_min=samples_ratio_min, + samples_ratio_max=samples_ratio_max, + samples_ratio_step=samples_ratio_step, + med3pa_metrics=med3pa_metrics, + evaluate_models=evaluate_models, + models_metrics=models_metrics, + mode=mode) + + results = Med3paResults(results_ref, results_testing) + med3pa_params = { + 'uncertainty_metric': uncertainty_metric, + 'samples_ratio_min': samples_ratio_min, + 'samples_ratio_max': samples_ratio_max, + 'samples_ratio_step': samples_ratio_step, + 'med3pa_metrics': med3pa_metrics, + 'evaluate_models': evaluate_models, + 'models_evaluation_metrics': models_metrics, + 'mode': mode, + 'ipc_model': ipc_config.get_info(), + 'apc_model': apc_config.get_info() if apc_config is not None else None, + } + experiment_config = { + 'experiment_name': "Med3paExperiment", + 'datasets': datasets_manager.get_info(), + 'base_model': base_model_manager.get_info() if base_model_manager is not None else None, + 'med3pa_params': med3pa_params + } + results.set_experiment_config(experiment_config) + results.set_models(ipc_config, apc_config) + return results + + @staticmethod + def _run_by_set(datasets_manager: DatasetsManager, + set: str = 'reference', + base_model_manager: BaseModelManager = None, + uncertainty_metric: str = 'absolute_error', + ipc_type: str = 'RandomForestRegressor', + ipc_params: Dict = None, + ipc_grid_params: Dict = None, + ipc_cv: int = 4, + pretrained_ipc: str = None, + ipc_instance: IPCModel = None, + apc_params: Dict = None, + apc_grid_params: Dict = None, + apc_cv: int = 4, + apc_instance: APCModel = None, + pretrained_apc: str = None, + samples_ratio_min: int = 0, + samples_ratio_max: int = 50, + samples_ratio_step: int = 5, + med3pa_metrics: List[str] = None, + evaluate_models: bool = False, + mode: str = 'mpc', + models_metrics: List[str] = None) -> Tuple[Med3paRecord, IPCModel, APCModel]: + + """ + Orchestrates the MED3PA experiment on one specific set of the dataset. + + Args: + datasets_manager (DatasetsManager): the datasets manager containing the dataset to use in the experiment. + base_model_manager (BaseModelManager, optional): Instance of BaseModelManager to get the base model, + by default None. + uncertainty_metric (str, optional): the uncertainty metric used to calculate uncertainty, + by default absolute_error. + ipc_type (str, optional): The regressor model to use for IPC, by default RandomForestRegressor. + ipc_params (dict, optional): Parameters for initializing the IPC regressor model, by default None. + ipc_grid_params (dict, optional): Grid search parameters for optimizing the IPC model, by default None. + ipc_cv (int, optional): Number of cross-validation folds for optimizing the IPC model, by default None. + apc_params (dict, optional): Parameters for initializing the APC regressor model, by default None. + apc_grid_params (dict, optional): Grid search parameters for optimizing the APC model, by default None. + apc_cv (int, optional): Number of cross-validation folds for optimizing the APC model, by default None. + samples_ratio_min (int, optional): Minimum sample ratio, by default 0. + samples_ratio_max (int, optional): Maximum sample ratio, by default 50. + samples_ratio_step (int, optional): Step size for sample ratio, by default 5. + med3pa_metrics (list of str, optional): List of metrics to calculate. + evaluate_models (bool, optional): Whether to evaluate the models, by default False. + mode (str): The modality of dataset, either 'ipc', 'apc', or 'mpc'. + models_metrics (list of str, optional): List of metrics for model evaluation. + + Returns: + Med3paRecord: the results of the MED3PA experiment. + IPCModel: The IPC model. + APCModel: The APC model. + """ + + # Step 1 : datasets and base model setting + # Retrieve the dataset based on the set type + if set == 'reference': + dataset = datasets_manager.get_dataset_by_type(dataset_type="reference", return_instance=True) + elif set == 'testing': + dataset = datasets_manager.get_dataset_by_type(dataset_type="testing", return_instance=True) + else: + raise ValueError("The set must be either the reference set or the testing set") + + # retrieve different dataset components needed for the experiment + x = dataset.get_observations() + y_true = dataset.get_true_labels() + predicted_probabilities = dataset.get_pseudo_probabilities() + features = datasets_manager.get_column_labels() + threshold = None + + # Initialize base model and predict probabilities if not provided + if base_model_manager is None and predicted_probabilities is None: + raise ValueError("Either the base model or the predicted probabilities should be provided!") + + if predicted_probabilities is None: + predicted_probabilities = base_model_manager.predict_proba(x)[:, 1] + threshold = base_model_manager.threshold + + dataset.set_pseudo_probs_labels(predicted_probabilities, threshold) + + # Step 2 : Mode and metrics setup + valid_modes = ['mpc', 'apc', 'ipc'] + if mode not in valid_modes: + raise ValueError(f"Invalid mode '{mode}'. The mode must be one of {valid_modes}.") + + if med3pa_metrics == []: + med3pa_metrics = ClassificationEvaluationMetrics.supported_metrics() + + # Step 3 : Calculate uncertainty values + uncertainty_calc = UncertaintyCalculator(uncertainty_metric) + uncertainty_values = uncertainty_calc.calculate_uncertainty(x, predicted_probabilities, y_true) + + # Step 4: Set up splits to evaluate the models + if evaluate_models: + x_train, x_test, uncertainty_train, uncertainty_test, y_train, y_test = train_test_split(x, + uncertainty_values, + y_true, + test_size=0.1, + random_state=42) + else: + x_train = x + uncertainty_train = uncertainty_values + + # Split the data if pretrained models are not available + if pretrained_ipc is None and pretrained_apc is None: + # Split the data in half: one half for IPC, the other half for APC + x_ipc, x_apc, uncertainty_ipc, uncertainty_apc, y_ipc, _ = train_test_split( + x_train, uncertainty_train, y_train, test_size=0.5, random_state=42) + else: + x_ipc, uncertainty_ipc = x_train, uncertainty_train + x_apc, uncertainty_apc = x_train, uncertainty_train + + results = Med3paRecord() + + # Step 5: Create and train IPCModel + if pretrained_ipc is None and ipc_instance is None: + IPC_model = IPCModel(model_name=ipc_type, params=ipc_params, pretrained_model=None) + if ipc_type == 'EnsembleRandomForestRegressor': + # Add class weight correction to train the EnsembleRandomForestRegressor + class_1_prop = np.sum(y_ipc) / len(y_ipc) + sample_weight = np.where(y_ipc == 0, 1 / (1 - class_1_prop), 1 / class_1_prop) + IPC_model.train(x_ipc, uncertainty_ipc, sample_weight=sample_weight) + else: + IPC_model.train(x_ipc, uncertainty_ipc) + print("IPC Model training complete.") + # Optimize IPC model if grid params were provided + if ipc_grid_params is not None: + if len(uncertainty_values) > 4: + # No optimization if 4 or less samples + if ipc_type == 'EnsembleRandomForestRegressor': + IPC_model.optimize(ipc_grid_params, ipc_cv, x_train, uncertainty_train, sample_weight) + else: + IPC_model.optimize(ipc_grid_params, ipc_cv, x_train, uncertainty_train) + print("IPC Model optimization complete.") + elif pretrained_ipc is not None: + IPC_model = IPCModel(model_name=ipc_type, params=ipc_params, pretrained_model=pretrained_ipc) + print("Loaded a pretrained IPC model.") + else: + IPC_model = ipc_instance + print("Used a trained IPC instance.") + + # Predict IPC values + IPC_values = IPC_model.predict(x) + print("Individualized confidence scores calculated.") + # Save the calculated confidence scores by the APCmodel + ipc_dataset = dataset.clone() + ipc_dataset.set_confidence_scores(IPC_values) + results.set_dataset(mode="ipc", dataset=ipc_dataset) + results.set_confidence_scores(IPC_values, "ipc") + metrics_by_dr = MDRCalculator.calc_metrics_by_dr(dataset=ipc_dataset, confidence_scores=IPC_values, + metrics_list=med3pa_metrics) + results.set_metrics_by_dr(metrics_by_dr) + if mode in ['mpc', 'apc']: + + # Step 6: Create and train APCModel + IPC_values = IPC_model.predict(x_apc) + if pretrained_apc is None and apc_instance is None: + APC_model = APCModel(features=features, params=apc_params) + APC_model.train(x_apc, IPC_values) + print("APC Model training complete.") + # optimize APC model if grid params were provided + if apc_grid_params is not None: + APC_model.optimize(apc_grid_params, apc_cv, x_apc, uncertainty_apc) + print("APC Model optimization complete.") + elif pretrained_apc is not None: + APC_model = APCModel(features=features, params=apc_params, pretrained_model=pretrained_apc) + APC_model.train(x_apc, IPC_values) + print("Loaded a pretrained APC model.") + else: + APC_model = apc_instance + print("Used a trainde IPC instance.") + + # Predict APC values + APC_values = APC_model.predict(x_apc) + print("Aggregated confidence scores calculated.") + # Save the tree structure created by the APCModel + tree = APC_model.treeRepresentation + results.set_tree(tree=tree) + # Save the calculated confidence scores by the APCmodel + apc_dataset = dataset.clone() + apc_dataset.set_confidence_scores(APC_model.predict(x)) + results.set_dataset(mode="apc", dataset=apc_dataset) + results.set_confidence_scores(APC_model.predict(x), "apc") + + # Step 7: Create and train MPCModel + if mode == 'mpc': + # Create and predict MPC values + MPC_model = MPCModel(IPC_values=IPC_model.predict(x), APC_values=APC_model.predict(x)) + MPC_values = MPC_model.predict() + # Save the calculated confidence scores by the MPCmodel + mpc_dataset = dataset.clone() + mpc_dataset.set_confidence_scores(MPC_values) + results.set_dataset(mode="mpc", dataset=mpc_dataset) + results.set_confidence_scores(MPC_values, "mpc") + else: + MPC_model = MPCModel(APC_values=APC_values) + MPC_values = MPC_model.predict() + mpc_dataset = dataset.clone() + mpc_dataset.set_confidence_scores(MPC_values) + + print("Mixed confidence scores calculated.") + + # Step 8: Calculate the profiles for the different samples_ratio and drs + profiles_manager = ProfilesManager(features) + for samples_ratio in range(samples_ratio_min, samples_ratio_max + 1, samples_ratio_step): + # Calculate profiles and their metrics by declaration rate + MDRCalculator.calc_profiles(profiles_manager, tree, mpc_dataset, features, MPC_values, samples_ratio) + MDRCalculator.calc_metrics_by_profiles(profiles_manager, mpc_dataset, features, MPC_values, + samples_ratio, med3pa_metrics) + results.set_profiles_manager(profiles_manager) + print("Results extracted for minimum_samples_ratio = ", samples_ratio) + + if mode in ['mpc', 'apc']: + ipc_config = IPC_model + apc_config = APC_model + if evaluate_models: + IPC_evaluation = IPC_model.evaluate(x_test, uncertainty_test, models_metrics) + APC_evaluation = APC_model.evaluate(x_test, uncertainty_test, models_metrics) + results.set_models_evaluation(IPC_evaluation, APC_evaluation) + else: + ipc_config = IPC_model + apc_config = None + if evaluate_models: + IPC_evaluation = IPC_model.evaluate(x_test, uncertainty_test, models_metrics) + results.set_models_evaluation(IPC_evaluation, None) + + return results, ipc_config, apc_config diff --git a/MED3pa/med3pa/mdr.py b/MED3pa/med3pa/mdr.py new file mode 100644 index 0000000..57517b7 --- /dev/null +++ b/MED3pa/med3pa/mdr.py @@ -0,0 +1,373 @@ +""" +Contains functionality for calculating metrics based on the predicted confidence and declaration rates (MDR). The +``MDRCalculator`` class offers methods to assess model performance across different declaration rates, +and to extract problematic profiles under specific declaration rates.""" + +import numpy as np +from typing import Dict + +from MED3pa.datasets import MaskedDataset +from MED3pa.med3pa.profiles import ProfilesManager +from MED3pa.med3pa.tree import TreeRepresentation +from MED3pa.models.classification_metrics import * + + +class MDRCalculator: + """ + Class to calculate various metrics and profiles for the MED3PA method. + """ + + @staticmethod + def _get_min_confidence_score(dr: int, confidence_scores: np.ndarray) -> float: + """ + Calculates the minimum confidence score based on the desired declaration rate. + + Args: + dr (int): Desired declaration rate as a percentage (0-100). + confidence_scores (np.ndarray): Array of confidence scores. + + Returns: + float: The minimum confidence level required to meet the desired declaration rate. + + Raises: + ValueError: If dr is not in the range 0-100. + """ + if not (0 <= dr <= 100): + raise ValueError("Declaration rate (dr) must be between 0 and 100 inclusive.") + + sorted_confidence_scores = np.sort(confidence_scores) + if dr == 0: + min_confidence_level = max(confidence_scores) + 1 # Higher than all confidence scores + elif dr == 100: + min_confidence_level = min(confidence_scores) - 1 # Lower than all confidence scores + else: + min_confidence_level = sorted_confidence_scores[int(len(sorted_confidence_scores) * (1 - dr / 100))] + return min_confidence_level + + @staticmethod + def _calculate_metrics(y_true: np.ndarray, y_pred: np.ndarray, predicted_prob: np.ndarray, metrics_list: list + ) -> dict: + """ + Calculate a variety of metrics based on the true labels, predicted labels, and predicted probabilities. + + Args: + y_true (np.ndarray): Array of true labels. + y_pred (np.ndarray): Array of predicted labels. + predicted_prob (np.ndarray): Array of predicted probabilities. + metrics_list (list): List of metric names to be calculated. + + Returns: + dict: A dictionary where keys are metric names and values are the calculated metric values. + + """ + metrics_dict = {} + for metric_name in metrics_list: + metric_function = ClassificationEvaluationMetrics.get_metric(metric_name) + if metric_function: + if metric_name in {'Auc', 'Auprc', 'Logloss'}: + metrics_dict[metric_name] = metric_function(y_true, predicted_prob) + else: + metrics_dict[metric_name] = metric_function(y_true, y_pred) + else: + raise ValueError(f"Error: The metric '{metric_name}' is not supported.") + return metrics_dict + + @staticmethod + def _list_difference_by_key(list1: List[Dict], list2: List[Dict], key='node_id') -> List[Dict]: + """ + Calculate the difference between two lists of Profile instances based on a specific key. + + Args: + list1 (List[Dict]): First list of Profile instances. + list2 (List[Dict]): Second list of Profile instances. + key (str): Key to compare for differences (default is 'node_id'). + + Returns: + List[Dict]: A list containing elements from list1 that do not appear in list2 based on the specified key. + """ + set1 = {d[key] for d in list1 if key in d} + set2 = {d[key] for d in list2 if key in d} + unique_to_list1 = set1 - set2 + return [d for d in list1 if d[key] in unique_to_list1] + + @staticmethod + def _filter_by_profile(dataset: MaskedDataset, path: List, features: list, min_confidence_level: float = None): + """ + Filters datasets based on specific profile conditions described by a path. + + Args: + dataset (MaskedDataset): The dataset to filter. + features (list): The list of features to filter on. + path (list): Conditions describing the profile path. + min_confidence_level(float): Possibility to filter according a minimum confidence score if specified. + + Returns: + tuple: Filtered datasets including observations, true labels, predicted probabilities, predicted labels, and mpc values. + """ + + # retrieve different dataset components to calculate the metrics + x = dataset.get_observations() + y_true = dataset.get_true_labels() + y_pred = dataset.get_pseudo_labels() + predicted_prob = dataset.get_pseudo_probabilities() + confidence_scores = dataset.get_confidence_scores() + + # Start with a mask that selects all rows + mask = np.ones(len(x), dtype=bool) + + for condition in path: + if condition == '*': + continue # Skip the root node indicator + + # Parse the condition string + column_name, operator, value_str = condition.split(' ') + column_index = features.index(column_name) # Map feature name to index + try: + value = float(value_str) + except ValueError: + # If conversion fails, the string is not a number. Handle it appropriately. + value = value_str # If it's supposed to be a string, leave it as string + + # Apply the condition to update the mask + if operator == '>': + mask &= x[:, column_index] > value + elif operator == '<': + mask &= x[:, column_index] < value + elif operator == '>=': + mask &= x[:, column_index] >= value + elif operator == '<=': + mask &= x[:, column_index] <= value + elif operator == '==': + mask &= x[:, column_index] == value + elif operator == '!=': + mask &= x[:, column_index] != value + else: + raise ValueError(f"Unsupported operator '{operator}' in condition '{condition}'.") + + # Filter the data according to the path mask + filtered_x = x[mask] + filtered_y_true = y_true[mask] + if predicted_prob is not None: + filtered_prob = predicted_prob[mask] + else: + filtered_prob = None + + if y_pred is not None: + filtered_y_pred = y_pred[mask] + else: + filtered_y_pred = None + + if confidence_scores is not None: # None for testing and reference sets + filtered_confidence_scores = confidence_scores[mask] + else: + filtered_confidence_scores = None + + # filter once again according to the min_confidence_level if specified + if min_confidence_level is not None: + filtered_x = filtered_x[filtered_confidence_scores >= min_confidence_level] + filtered_y_true = filtered_y_true[filtered_confidence_scores >= min_confidence_level] + filtered_prob = filtered_prob[ + filtered_confidence_scores >= min_confidence_level] if predicted_prob is not None else None + filtered_y_pred = filtered_y_pred[ + filtered_confidence_scores >= min_confidence_level] if y_pred is not None else None + filtered_confidence_scores = filtered_confidence_scores[ + filtered_confidence_scores >= min_confidence_level] if confidence_scores is not None else None + + return filtered_x, filtered_y_true, filtered_prob, filtered_y_pred, filtered_confidence_scores + + @staticmethod + def calc_metrics_by_dr(dataset: MaskedDataset, confidence_scores: np.ndarray, metrics_list: list + ) -> Dict[int, Dict]: + """ + Calculate metrics by declaration rates (DR), evaluating model performance at various thresholds of predicted + accuracies. + + Args: + dataset (MaskedDataset): The dataset to filter. + confidence_scores (np.ndarray): the confidence scores used for filtering. + metrics_list (list): List of metric names to be calculated (e.g., 'AUC', 'Accuracy'). + + Returns: + Dict: A dictionary containing metrics computed for each declaration rate from 100% to 0%, including metrics + and population percentage. + """ + + # retrieve different dataset components to calculate the metrics + y_true = dataset.get_true_labels() + y_pred = dataset.get_pseudo_labels() + predicted_prob = dataset.get_pseudo_probabilities() + + # initialize the dictionaries used for results storage + metrics_by_dr = {} # global dictionary containing all the declaration rates and their corresponding metrics + last_dr_values = {} # used to save last dr calculated metrics + last_min_confidence_level = -1 + + # for each declaration rate + for dr in range(100, -1, -1): + # calculate the minimum confidence level + min_confidence_level = MDRCalculator._get_min_confidence_score(dr, confidence_scores) + + # if the current confidence level is different from the last one + if last_min_confidence_level != min_confidence_level: + + # update the last confidence level + last_min_confidence_level = min_confidence_level + + # save the confidence level in the dict of the current dr + dr_values = {'min_confidence_level': min_confidence_level} + + # defines the mask to keep only data with higher min_confidence levels + confidence_mask = confidence_scores >= min_confidence_level + + # save the left population percentage + dr_values['population_percentage'] = sum(confidence_mask) / len(confidence_scores) + dr_values['mean_confidence_level'] = np.mean(confidence_scores[confidence_mask]) if confidence_scores[ + confidence_mask].size > 0 else None + dr_values['Positive%'] = np.sum(y_true[confidence_mask]) / len(y_true[confidence_mask]) * 100 if \ + len(y_true[confidence_mask]) > 0 else None + # Calculate the metrics for the current DR + metrics_dict = MDRCalculator._calculate_metrics(y_true[confidence_mask], y_pred[confidence_mask], + predicted_prob[confidence_mask], metrics_list) + + # save the calculated metrics + dr_values['metrics'] = metrics_dict + + # update the last dr dictionary metrics + last_dr_values = dr_values + + # save it in the global dictionary + metrics_by_dr[dr] = dr_values + + # if the min_confidence level is the same, use the last dr results + else: + metrics_by_dr[dr] = last_dr_values + + # return the global dictionary + return metrics_by_dr + + @staticmethod + def calc_profiles(profiles_manager: ProfilesManager, tree: TreeRepresentation, dataset: MaskedDataset, + features: list, confidence_scores: np.ndarray, min_samples_ratio: int) -> Dict[int, float]: + """ + Calculates profiles for different declaration rates and minimum sample ratios. This method assesses how profiles + change across different confidence levels derived from predicted accuracies. + + Args: + profiles_manager (ProfilesManager): Manager for storing and retrieving profile information. + tree (TreeRepresentation): Tree structure from which profiles are derived. + dataset (MaskedDataset): The dataset to filter. + features (list): the list of features to filter on. + confidence_scores (np.ndarray): Array of predicted accuracy values used for thresholding profiles. + min_samples_ratio (int): Minimum sample ratio to consider for including a profile. + + Returns: + Dict[int, float]: A dictionary with declaration rates as keys and their corresponding minimum confidence levels as values. + """ + + # Initialization of different variables + all_nodes = tree.get_all_nodes() # Retrieve all nodes from the tree + last_profiles = all_nodes # Initialize last profiles as all nodes + lost_profiles_all = [] # Saves lost profiles + last_min_confidence_level = -1 # Last min confidence level + min_confidence_levels_dict = {} # Saves the min_confidence_level thresholds + + # Go through all declaration rates + for dr in range(100, -1, -1): + + # Calculate the min confidence level for this dr + min_confidence_level = MDRCalculator._get_min_confidence_score(dr, confidence_scores) + min_confidence_levels_dict[dr] = min_confidence_level + + # If the current confidence level is different from the last one + if min_confidence_level != last_min_confidence_level: + + # Update the last min confidence level + last_min_confidence_level = min_confidence_level + # Saves the profiles of this dr + profiles_current = [] + + # Calculate mean_ca and samples_ratio for all nodes to see if this node is eligible as a profile + for node in all_nodes: + # filter the data that belongs to this node, and filter according to min_confidence_level threshold + _, _, _, _, filtered_confidence_scores = MDRCalculator._filter_by_profile( + dataset, node['path'], features=features, min_confidence_level=min_confidence_level) + + # calculate the samples_ratio (pop%) and mean_confidence_level of this node + if len(filtered_confidence_scores) > 0: + samples_ratio = len(filtered_confidence_scores) / len(confidence_scores) * 100 + mean_confidence = np.mean( + filtered_confidence_scores) if filtered_confidence_scores.size > 0 else 0 + # if the calculated samples_ratio and mean_confidence meet the conditions, keep this node + if samples_ratio >= min_samples_ratio and mean_confidence >= min_confidence_level: + profiles_current.append(node) + + # If the last profiles are different from current profiles + if len(last_profiles) != len(profiles_current): + # Extract lost profiles + lost_profiles = MDRCalculator._list_difference_by_key(last_profiles, profiles_current) + lost_profiles_all.extend(lost_profiles) + + # Update the last profiles + last_profiles = profiles_current + + # If the current min_confidence is same as the last one, use the last dr results + profiles_current_ins = profiles_manager.transform_to_profiles(profiles_current) + lost_profiles_current_ins = profiles_manager.transform_to_profiles(lost_profiles_all) + profiles_manager.insert_profiles(dr, min_samples_ratio, profiles_current_ins) + profiles_manager.insert_lost_profiles(dr, min_samples_ratio, lost_profiles_current_ins) + + return min_confidence_levels_dict + + @staticmethod + def calc_metrics_by_profiles(profiles_manager, dataset: MaskedDataset, features: List, + confidence_scores: np.ndarray, min_samples_ratio: int, metrics_list: List) -> None: + """ + Calculates various metrics for different profiles and declaration rates based on provided datasets. + + Args: + profiles_manager (ProfilesManager): Manager handling profiles. + dataset (MaskedDataset): the dataset to use. + features (List): The list of features to filter on. + confidence_scores (np.ndarray): Array of predicted accuracy values used for thresholding profiles. + min_samples_ratio (int): Minimum sample ratio to consider for including a profile. + metrics_list (List): List of metrics to calculate. + + """ + # retrieve different dataset components to calculate the metrics + all_y_true = dataset.get_true_labels() + all_confidence_scores = confidence_scores + + dr_dict = profiles_manager.profiles_records.get(min_samples_ratio) + + # go through all profiles, for each ratio and for each dr + if dr_dict is not None: + # for each dr and its profiles stored in the ratio + for dr, profiles in dr_dict.items(): + # calculate the min_confidence level + min_confidence_level = MDRCalculator._get_min_confidence_score(dr, all_confidence_scores) + + # go through each profile in the profile list + for profile in profiles: + x, y_true, pred_prob, y_pred, confidence_scores = MDRCalculator._filter_by_profile(dataset, + profile.path, + features) + # calculate the metrics for this profile + confidence_mask = confidence_scores >= min_confidence_level + metrics_dict = MDRCalculator._calculate_metrics(y_true=y_true[confidence_mask], + y_pred=y_pred[confidence_mask], + predicted_prob=pred_prob[confidence_mask], + metrics_list=metrics_list) + info_dict = {} + # the remaining node population at the current dr compared to node population at dr = 100 + info_dict['Node%'] = len(y_true[confidence_mask]) * 100 / len(y_true) + # the remaining node population at the current dr compared to the whole population at dr = 100 + info_dict['Population%'] = len(y_true[confidence_mask]) * 100 / len(all_y_true) + # the mean confidence level for this profile at this dr + info_dict['Mean confidence level'] = np.mean(confidence_scores[confidence_mask]) * 100 if \ + confidence_scores[confidence_mask].size > 0 else None + # the positive class percentage in this profile at this dr + info_dict['Positive%'] = np.sum(y_true[confidence_mask]) / len(y_true[confidence_mask]) * 100 if \ + len(y_true[confidence_mask]) > 0 else None + # update the calculated metrics in the profile + profile.update_metrics_results(metrics_dict) + profile.update_node_information(info_dict) diff --git a/MED3pa/med3pa/models.py b/MED3pa/med3pa/models.py new file mode 100644 index 0000000..254f722 --- /dev/null +++ b/MED3pa/med3pa/models.py @@ -0,0 +1,382 @@ +"""Defines the models used within the MED3pa framework. It includes classes for Individualized Predictive Confidence +(IPC) models that predict uncertainty at an individual level, where the regressor type can be specified by the user. +Additionally, it includes Aggregated Predictive Confidence (APC) models that predict uncertainty for groups of +similar data points, and Mixed Predictive Confidence (MPC) models that combine the predictions from IPC and APC +models. +""" + +import json +import numpy as np +import pandas as pd +import pickle +from sklearn.model_selection import GridSearchCV +from sklearn.ensemble import RandomForestRegressor +from typing import Any, Dict, List, Optional + +from MED3pa.med3pa.tree import TreeRepresentation +from MED3pa.models.concrete_regressors import (DecisionTreeRegressorModel, RandomForestRegressorModel, + EnsembleRandomForestRegressorModel) +from MED3pa.models.data_strategies import ToDataframesStrategy +from MED3pa.models import rfr_params, dtr_params + + +class AbstractUncertaintyEstimator: + default_params = {'random_state': 54288} + + supported_regressors_mapping = { + 'RandomForestRegressor': RandomForestRegressorModel, + 'EnsembleRandomForestRegressor': EnsembleRandomForestRegressorModel, + 'DecisionTreeRegressor': DecisionTreeRegressorModel + } + + supported_regressors_params = { + 'RandomForestRegressor': { + 'params': rfr_params.rfr_params, + 'grid_params': rfr_params.rfr_gridsearch_params + }, + 'EnsembleRandomForestRegressor': { + 'params': rfr_params.rfr_params, + 'grid_params': rfr_params.rfr_gridsearch_params + } + } + + def __init__(self, model_name: str, + params: Optional[Dict[str, Any]] = None, + pretrained_model: Optional[str] = None): + """ + Initializes the AbstractUncertaintyEstimator class instance. + + Args: + model_name (str): Name of the model. + params (Optional[Dict[str, Any]]): Parameters to initialize the regression model, default is None. + pretrained_model (Optional[str]): Path to a pretrained model, default is None. + """ + if model_name not in self.supported_regressors_mapping: + raise ValueError( + f"Unsupported model name: {model_name}. Supported models are: " + f"{list(self.supported_regressors_mapping)}") + + model_class = self.supported_regressors_mapping[model_name] + + if params is None: + params = self.default_params.copy() + elif 'random_state' not in params: + params['random_state'] = self.default_params['random_state'] + + self.model = model_class(params) + self.params = params + self.grid_search_params = {} + self.optimized = False + self.pretrained = False + self.model_name = model_name + + if pretrained_model is not None: + self.load_model(pretrained_model) + + def evaluate(self, X: np.ndarray, y: np.ndarray, eval_metrics: List[str], print_results: bool = False + ) -> Dict[str, float]: + """ + Evaluates the model using specified metrics. + + Args: + X (np.ndarray): observations for evaluation. + y (np.ndarray): True labels for evaluation. + eval_metrics (List[str]): Metrics to use for evaluation. + print_results (bool): Whether to print the evaluation results. + + Returns: + Dict[str, float]: A dictionary with metric names and their evaluated scores. + """ + evaluation_results = self.model.evaluate(X, y, eval_metrics, print_results) + return evaluation_results + + def get_info(self) -> Dict[str, Any]: + """ + Returns information about the AbstractUncertaintyEstimator instance. + + Returns: + Dict[str, Any]: A dictionary containing the model name, parameters, whether the model was optimized, and other relevant details. + """ + return { + 'model_name': self.model_name, + 'params': self.params if not self.pretrained else {}, + 'optimized': self.optimized, + 'grid_search_params': self.grid_search_params, + 'pretrained': self.pretrained + } + + def save_model(self, file_path: str) -> None: + """ + Saves the trained model to a pickle file. + + Args: + file_path (str): The path to the file where the model will be saved. + """ + with open(file_path, 'wb') as file: + pickle.dump(self.model, file) + + def load_model(self, file_path: str) -> None: + """ + Loads a pre-trained model from a pickle file. + + Args: + file_path (str): The path to the pickle file. + """ + with open(file_path, 'rb') as file: + loaded_model = pickle.load(file) + + if not isinstance(loaded_model, self.supported_regressors_mapping[self.model_name]): + raise TypeError(f"The loaded model type does not match the specified model type: {self.model_name}") + + self.model = loaded_model + self.pretrained = True + + +class IPCModel(AbstractUncertaintyEstimator): + """ + IPCModel class used to predict the Individualized predicted confidence. ie, the base model confidence for each data + point. + """ + default_params = {'random_state': 54288} + + def __init__(self, model_name: str = 'RandomForestRegressor', params: Optional[Dict[str, Any]] = None, + pretrained_model: Optional[str] = None) -> None: + """ + Initializes the IPCModel with a regression model class name and optional parameters. + + Args: + model_name (str): The name of the regression model class to use, default is 'RandomForestRegressor'. + Allowed values are in IPCModel.supported_regressors_mapping. + params (Optional[Dict[str, Any]]): Parameters to initialize the regression model, default is None. + pretrained_model (Optional[str]): Path to a pretrained regression model, serving as ipc model, + default is None. + """ + super().__init__(model_name=model_name, params=params, pretrained_model=pretrained_model) + + @classmethod + def supported_ipc_models(cls) -> List: + """ + Returns a list of supported IPC models. + + Returns: + list: A list of supported regression model names. + """ + return list(AbstractUncertaintyEstimator.supported_regressors_mapping) + + @classmethod + def supported_models_params(cls) -> Dict[str, Dict[str, Any]]: + """ + Returns a dictionary containing the supported models and their parameters and grid search parameters. + + Returns: + Dict[str, Dict[str, Any]]: A dictionary with model names as keys and another dictionary as value containing + 'params' and 'grid_search_params' for each model. + """ + return AbstractUncertaintyEstimator.supported_regressors_params + + def optimize(self, param_grid: dict, cv: int, x: np.ndarray, confidence_score: np.ndarray, + sample_weight: np.ndarray = None) -> None: + """ + Optimizes the model parameters using GridSearchCV. + + Args: + param_grid (Dict[str, Any]): The parameter grid to explore. + cv (int): The number of cross-validation folds. + confidence_score (np.ndarray): The confidence scores to train the IPCModel. + x (np.ndarray): Training data. + sample_weight (Optional[np.ndarray]): Weights for the training samples. + """ + if sample_weight is None: + sample_weight = np.full(x.shape[0], 1) + + grid_search = GridSearchCV(estimator=self.model.model, param_grid=param_grid, cv=cv, n_jobs=-1, verbose=0) + grid_search.fit(x, confidence_score, sample_weight=sample_weight) + + self.model.set_model(grid_search.best_estimator_) + self.model.update_params(grid_search.best_params_) + self.params.update(grid_search.best_params_) + self.grid_search_params = param_grid + self.optimized = True + + def train(self, x: np.ndarray, confidence_score: np.ndarray, **params) -> None: + """ + Trains the model on the provided training data and error probabilities. + + Args: + x (np.ndarray): Feature matrix for training. + confidence_score (np.ndarray): The confidence scores corresponding to each training instance. + """ + self.model.train(x, confidence_score, **params) + + def predict(self, x: np.ndarray) -> np.ndarray: + """ + Predicts error probabilities for the given input observations using the trained model. + + Args: + x (np.ndarray): Feature matrix for which to predict error probabilities. + + Returns: + np.ndarray: Predicted error probabilities. + """ + return self.model.predict(x) + + +class APCModel(AbstractUncertaintyEstimator): + """ + APCModel class used to predict the Aggregated predicted confidence. ie, the base model confidence for a group of + similar data points. + """ + default_params = {'max_depth': 3, 'min_samples_leaf': 1, 'random_state': 54288} + + supported_params = { + 'DecisionTreeRegressor': { + 'params': dtr_params.dtr_params, + 'grid_params': dtr_params.dtr_gridsearch_params + } + } + + def __init__(self, features: List[str], params: Optional[Dict[str, Any]] = None, + tree_file_path: Optional[str] = None, pretrained_model: Optional[str] = None, + model_name: str = "DecisionTreeRegressor") -> None: + """ + Initializes the APCModel with the necessary components to perform tree-based regression and to build a tree + representation. + + Args: + features (List[str]): List of features used in the model. + params (Optional[Dict[str, Any]]): Parameters to initialize the regression model, default is settings for + a basic decision tree. + tree_file_path (Optional[str]): Path to the saved tree JSON file, default is None. + pretrained_model (Optional[str]): Path to a pretrained DecisionTree model, serving as apc model, + default is None. + model_name (str): Name of the model, default is "DecisionTreeRegressor". + """ + super().__init__(model_name=model_name, params=params, pretrained_model=pretrained_model) + + self.treeRepresentation = TreeRepresentation(features=features) + self.dataPreparationStrategy = ToDataframesStrategy() + self.features = features + self.loaded_tree = None + + if tree_file_path: + self.load_tree(tree_file_path) + + def load_tree(self, file_path: str) -> None: + """ + Loads the tree structure from a JSON file and initializes the tree representation. + + Args: + file_path (str): The file path from which the tree structure will be loaded. + """ + with open(file_path, 'r') as file: + tree_dict = json.load(file) + + self.loaded_tree = tree_dict + + @classmethod + def supported_models_params(cls) -> Dict[str, Dict[str, Any]]: + """ + Returns a dictionary containing the supported models and their parameters and grid search parameters. + + Returns: + Dict[str, Dict[str, Any]]: A dictionary with model names as keys and another dictionary as value containing + 'params' and 'grid_search_params' for each model. + """ + return cls.supported_params + + def train(self, x: np.ndarray, error_prob: np.ndarray | pd.Series) -> None: + """ + Trains the model using the provided data and error probabilities and builds the tree representation. + + Args: + x (np.ndarray): Feature matrix for training. + error_prob (np.ndarray | pd.Series): Error probabilities corresponding to each training instance. + """ + if not self.pretrained: + self.model.train(x, error_prob) + df_X, df_y, df_w = self.dataPreparationStrategy.execute(column_labels=self.features, observations=x, + labels=error_prob) + self.treeRepresentation.head = self.treeRepresentation.build_tree(self.model, df_X, error_prob, 0) + + def optimize(self, param_grid: dict, cv: int, x: np.ndarray, confidence_score: np.ndarray, + sample_weight: np.ndarray = None) -> None: + """ + Optimizes the model parameters using GridSearchCV. + + Args: + param_grid (Dict[str, Any]): The parameter grid to explore. + cv (int): The number of cross-validation folds. + x (np.ndarray): Training data. + confidence_score (np.ndarray): The confidence scores to train the APCModel. + sample_weight (Optional[np.ndarray]): Weights for the training samples. + """ + if sample_weight is None: + sample_weight = np.full(x.shape[0], 1) + grid_search = GridSearchCV(estimator=self.model.model, param_grid=param_grid, cv=cv, n_jobs=-1, verbose=0) + grid_search.fit(x, confidence_score, sample_weight=sample_weight) + self.model.set_model(grid_search.best_estimator_) + self.model.update_params(grid_search.best_params_) + self.params.update(grid_search.best_params_) + self.grid_search_params = param_grid + df_X, df_y, df_w = self.dataPreparationStrategy.execute(column_labels=self.features, observations=x, + labels=confidence_score) + self.treeRepresentation.build_tree(self.model, df_X, confidence_score, node_id=0) + self.optimized = True + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Predicts error probabilities using the tree representation for the given input observations. + + Args: + X (np.ndarray): Feature matrix for which to predict error probabilities. + + Returns: + np.ndarray: Predicted error probabilities based on the aggregated confidence levels. + """ + df_X, _, _ = self.dataPreparationStrategy.execute(column_labels=self.features, observations=X, labels=None) + predictions = [] + + for index, row in df_X.iterrows(): + if self.treeRepresentation.head is not None: + prediction = self.treeRepresentation.head.assign_node(row) + predictions.append(prediction) + else: + raise ValueError("The Tree Representation has not been initialized, try fitting the APCModel first.") + + return np.array(predictions) + + +class MPCModel: + """ + MPCModel class used to predict the Mixed predicted confidence. ie, the minimum between the APC and IPC values. + """ + + def __init__(self, IPC_values: np.ndarray = None, APC_values: np.ndarray = None) -> None: + """ + Initializes the MPCModel with IPC and APC values. + + Args: + IPC_values (np.ndarray): IPC values. + APC_values (np.ndarray): APC values. + """ + self.IPC_values = IPC_values + self.APC_values = APC_values + + def predict(self) -> np.ndarray: + """ + Combines IPC and APC values to predict MPC values. + + Returns: + np.ndarray: Combined MPC values. + """ + if self.APC_values is None and self.IPC_values is None: + raise ValueError("Both APC values and IPC values are not set!") + + if self.APC_values is None: + MPC_values = self.IPC_values + elif self.IPC_values is None: + MPC_values = self.APC_values + else: + MPC_values = np.minimum(self.IPC_values, self.APC_values) + + return MPC_values diff --git a/MED3pa/med3pa/profiles.py b/MED3pa/med3pa/profiles.py new file mode 100644 index 0000000..1e85a1f --- /dev/null +++ b/MED3pa/med3pa/profiles.py @@ -0,0 +1,167 @@ +"""Handles the management and storage of profiles derived from the tree representation. It defines a ``Profile`` +class to encapsulate metrics and values associated with a specific node in the tree and a ``ProfilesManager`` class +to manage collections of profiles and track lost profiles during analysis.""" + +from typing import Dict, List + + +class Profile: + """ + Represents a profile containing metrics and values associated with a specific node. + """ + + def __init__(self, node_id: int, path: List[str]) -> None: + """ + Initializes a Profile instance with node details and associated metrics. + + Args: + node_id (int): The identifier for the node associated with this profile. + path (List[str]): A List OF String containing representation of the path to this node within the tree. + """ + self.node_id = node_id + self.path = path + self.mean_value = None + self.metrics = None + self.node_information = None + + def to_dict(self, save_all: bool = True) -> Dict: + """ + Converts the Profile instance into a dictionary format suitable for serialization. + + Returns: + dict: A dictionary representation of the Profile instance including the node ID, path, mean value, + metrics. + """ + if save_all: + return { + 'id': self.node_id, + 'path': self.path, + 'metrics': self.metrics, + 'node information': self.node_information + } + else: + return { + 'id': self.node_id, + 'path': self.path, + } + + def update_metrics_results(self, metrics: dict) -> None: + """ + Updates the metrics associated with this profile. + + Args: + metrics (dict): The results to be added to the profile. + """ + self.metrics = metrics + + def update_node_information(self, info: dict) -> None: + """ + Updates the information associated with this profile. + + Args: + info (dict): The updated node information. + """ + self.node_information = info + + +class ProfilesManager: + """ + Manages the records of profiles and lost profiles based on declaration rates and minimal samples ratio. + """ + + def __init__(self, features: List[str]) -> None: + """ + Initializes the ProfilesManager with a set of features. + + Args: + features (List[str]): A list of features considered in the profiles. + """ + self.profiles_records = {} + self.lost_profiles_records = {} + self.features = features + + def insert_profiles(self, dr: int, min_samples_ratio: int, profiles: List[Profile]) -> None: + """ + Inserts profiles into the records under a specific dr value and minimum sample ratio. + + Args: + dr (int): Desired declaration rate as a percentage. + min_samples_ratio (int): Minimum samples ratio. + profiles (List[Profile]): The profiles to insert. + """ + if min_samples_ratio not in self.profiles_records: + self.profiles_records[min_samples_ratio] = {} + self.profiles_records[min_samples_ratio][dr] = profiles.copy() + + def insert_lost_profiles(self, dr: int, min_samples_ratio: int, profiles: List[Profile]) -> None: + """ + Inserts lost profiles into the records under a specific dr value and minimum sample ratio. + + Args: + dr (int): Desired declaration rate as a percentage. + min_samples_ratio (int): Minimum samples ratio. + profiles (List[Profile]): The profiles to insert. + """ + if min_samples_ratio not in self.lost_profiles_records: + self.lost_profiles_records[min_samples_ratio] = {} + self.lost_profiles_records[min_samples_ratio][dr] = profiles.copy() + + def get_profiles(self, min_samples_ratio: int = None, dr: int = None) -> Dict: + """ + Retrieves profiles based on the specified minimum sample ratio and dr value. + + Args: + dr (int): Desired declaration rate as a percentage. + min_samples_ratio (int): Minimum samples ratio. + + Returns: + Dict[Profile]: Profiles with the specified minimum sample ratio. + """ + + if min_samples_ratio is not None: + if dr is not None: + if min_samples_ratio not in self.profiles_records: + raise ValueError("The profiles for this min_samples_ratio have not been calculated yet!") + return self.profiles_records[min_samples_ratio][dr] + return self.profiles_records[min_samples_ratio] + return self.profiles_records + + def get_lost_profiles(self, min_samples_ratio: int = None, dr: int = None) -> Dict: + """ + Retrieves lost profiles based on the specified minimum sample ratio and dr value. + + Args: + min_samples_ratio (int): Minimum samples ratio. + dr (int): Desired declaration rate as a percentage. + + Returns: + Dict: Lost profiles with the specified minimum sample ratio and dr value. + """ + if min_samples_ratio is not None: + if dr is not None: + if min_samples_ratio not in self.lost_profiles_records: + raise ValueError("The lost profiles for this min_samples_ratio have not been calculated yet!") + return self.lost_profiles_records[min_samples_ratio][dr] + return self.lost_profiles_records[min_samples_ratio] + return self.lost_profiles_records + + @staticmethod + def transform_to_profiles(profiles_list: List[dict], to_dict: bool = False) -> List[Profile | Dict]: + """ + Transforms a list of profile data into instances of the Profile class or dictionaries. + + Args: + profiles_list (List[dict]): List of profiles data. + to_dict (bool, optional): If True, transforms profiles to dictionaries. Defaults to False. + + Returns: + List[Union[dict, Profile]]: List of transformed profiles. + """ + profiles = [] + for profile in profiles_list: + if to_dict: + profile_ins = Profile(profile['node_id'], profile['path']).to_dict() + else: + profile_ins = Profile(profile['node_id'], profile['path']) + profiles.append(profile_ins) + return profiles diff --git a/MED3pa/med3pa/results.py b/MED3pa/med3pa/results.py new file mode 100644 index 0000000..5bce3bd --- /dev/null +++ b/MED3pa/med3pa/results.py @@ -0,0 +1,429 @@ +""" +This module stores and manages the results of the MED3pa experiments. +It includes the ``Med3paRecord`` class, responsible for storing and managing results for each set, +and the ``Med3paResult`` class, responsible for storing and managing all results of the experiment. +""" + +import datetime +import json +import numpy as np +import os +from typing import Any, Dict, TextIO + +from MED3pa.datasets import MaskedDataset +from MED3pa.med3pa.models import APCModel, IPCModel +from MED3pa.med3pa.profiles import Profile, ProfilesManager +from MED3pa.med3pa.tree import TreeRepresentation + + +def to_serializable(obj: Any, additional_arg: Any = None) -> Any: + """ + Convert an object to a JSON-serializable format. + Args: + obj (Any): The object to convert. + additional_arg (Any): Additional arguments to serialize + Returns: + Any: The JSON-serializable representation of the object. + """ + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, Profile): + if additional_arg is not None: + return obj.to_dict(additional_arg) + else: + return obj.to_dict() + if isinstance(obj, dict): + return {k: to_serializable(v, additional_arg) for k, v in obj.items()} + if isinstance(obj, list): + return [to_serializable(v, additional_arg) for v in obj] + return obj + + +class Med3paRecord: + """ + Class to store and manage results from the MED3PA method on one set. + """ + + def __init__(self) -> None: + self.metrics_by_dr: Dict[int, Dict] = {} + self.models_evaluation: Dict[str, Dict] = {} + self.profiles_manager: ProfilesManager = None + self.datasets: Dict[str, MaskedDataset] = {} + self.experiment_config = {} + self.tree = None + self.ipc_scores = None + self.apc_scores = None + self.mpc_scores = None + + def set_metrics_by_dr(self, metrics_by_dr: Dict) -> None: + """ + Sets the calculated metrics by declaration rate. + Args: + metrics_by_dr (Dict): Dictionary of metrics by declaration rate. + """ + self.metrics_by_dr = metrics_by_dr + + def set_profiles_manager(self, profile_manager: ProfilesManager) -> None: + """ + Sets the profile manager for this Med3paResults instance. + + Args: + profile_manager (ProfilesManager): The ProfileManager instance. + """ + self.profiles_manager = profile_manager + + def set_models_evaluation(self, ipc_evaluation: Dict, apc_evaluation: Dict = None) -> None: + """ + Sets models evaluation metrics. + Args: + ipc_evaluation (Dict): Evaluation metrics for IPC model. + apc_evaluation (Dict): Evaluation metrics for APC model. + """ + self.models_evaluation['IPC_evaluation'] = ipc_evaluation + + if apc_evaluation is not None: + self.models_evaluation['APC_evaluation'] = apc_evaluation + + def set_tree(self, tree: TreeRepresentation) -> None: + """ + Sets the constructed tree. + """ + self.tree = tree + + def set_dataset(self, mode: str, dataset: MaskedDataset) -> None: + """ + Saves the dataset for a given sample ratio. + Args: + mode (str): The modality of dataset, either 'ipc', 'apc', or 'mpc'. + dataset (MaskedDataset): The MaskedDataset instance. + """ + + self.datasets[mode] = dataset + + def save(self, file_path: str) -> None: + """ + Saves the experiment results. + Args: + file_path (str): The file path to save the JSON files. + """ + # Ensure the main directory exists + os.makedirs(file_path, exist_ok=True) + + metrics_file_path = os.path.join(file_path, 'metrics_dr.json') + with open(metrics_file_path, 'w') as file: + json.dump(self.metrics_by_dr, file, default=to_serializable, indent=4) + + if self.profiles_manager is not None: + profiles_file_path = os.path.join(file_path, 'profiles.json') + with open(profiles_file_path, 'w') as file: + json.dump(self.profiles_manager.get_profiles(), file, default=to_serializable, indent=4) + + lost_profiles_file_path = os.path.join(file_path, 'lost_profiles.json') + with open(lost_profiles_file_path, 'w') as file: + json.dump(self.profiles_manager.get_lost_profiles(), file, + default=lambda x: to_serializable(x, additional_arg=False), indent=4) + + if self.models_evaluation is not None: + models_evaluation_file_path = os.path.join(file_path, 'models_evaluation.json') + with open(models_evaluation_file_path, 'w') as file: + json.dump(self.models_evaluation, file, default=to_serializable, indent=4) + + for mode, dataset in self.datasets.items(): + dataset_path = os.path.join(file_path, f'dataset_{mode}.csv') + dataset.save_to_csv(dataset_path) + + if self.tree is not None: + tree_path = os.path.join(file_path, 'tree.json') + self.tree.save_tree(tree_path) + + def save_to_dict(self) -> dict: + """ + Collects the experiment results in a dictionary. + Returns: + dict: A dictionary containing all the saved elements. + """ + result = {} + + # Store profiles if available + if self.profiles_manager is not None: + result['lost_profiles'] = to_serializable(self.profiles_manager.get_lost_profiles(), False) + result['profiles'] = to_serializable(self.profiles_manager.get_profiles()) + + # Store metrics by declaration rate (DR) + result['metrics_dr'] = self.metrics_by_dr + + # Store models evaluation if available + if self.models_evaluation is not None: + result['models_evaluation'] = self.models_evaluation + + # Store tree structure if available + if self.tree is not None: + result['tree'] = self.tree.to_dict() + + return result + + def get_profiles_manager(self) -> ProfilesManager: + """ + Retrieves the profiles manager for this Med3paResults instance + """ + return self.profiles_manager + + def set_confidence_scores(self, scores: np.ndarray, mode: str) -> None: + """ + Sets the confidence scores for this Med3paResults. + + Args: + scores: The confidence scores for this Med3paResults. + mode: The modality of model for these confidence scores. Either 'ipc', 'apc' or 'mpc'. + """ + if mode == 'ipc': + self.ipc_scores = scores + elif mode == "apc": + self.apc_scores = scores + elif mode == "mpc": + self.mpc_scores = scores + + def get_confidence_scores(self, mode: str) -> np.ndarray: + """ + Retrieves the confidence scores. + + Args: + mode: The modality of model for these confidence scores. Either 'ipc' or 'apc' or 'mpc'. + + Returns: + The confidence scores for this Med3paResults and given model modality. + """ + if mode == 'ipc': + return self.ipc_scores + elif mode == "apc": + return self.apc_scores + elif mode == "mpc": + return self.mpc_scores + + +class Med3paResults: + """ + Class to store and manage results from the MED3PA complete experiment. + """ + + def __init__(self, reference_record: Med3paRecord, test_record: Med3paRecord) -> None: + """ + Initializes the Med3paResults class. + + Args: + reference_record: The reference record for the MED3pa experiment. + test_record: The test record for the MED3pa experiment. + """ + self.reference_record = reference_record + self.test_record = test_record + self.experiment_config = {} + self.ipc_model = None + self.apc_model = None + + def set_experiment_config(self, config: Dict[str, Any]) -> None: + """ + Sets or updates the configuration for the MED3pa experiment. + Args: + config (Dict[str, Any]): A dictionary of experiment configuration. + """ + self.experiment_config.update(config) + + def set_models(self, ipc_model: IPCModel, apc_model: APCModel = None) -> None: + """ + Sets the confidence models for the Med3pa experiment. + + Args: + ipc_model (IPCModel): The IPC model to predict individualized confidence predictions. + apc_model (APCModel): The APC model to predict aggregated confidence predictions by profiles. + """ + self.ipc_model = ipc_model + self.apc_model = apc_model + + def save(self, file_path: str, save_med3paResults: bool = True) -> None: + """ + Saves the experiment results. + Args: + file_path (str): The file path to save the JSON files. + save_med3paResults (bool): Whether to save the results in a Med3paResults file. Defaults to True. + """ + results = {} + # Ensure the main directory exists + os.makedirs(file_path, exist_ok=True) + + reference_path = os.path.join(file_path, 'reference') + test_path = os.path.join(file_path, 'test') + + if self.reference_record: + self.reference_record.save(file_path=reference_path) + results['reference'] = self.reference_record.save_to_dict() + self.test_record.save(file_path=test_path) + results['test'] = self.test_record.save_to_dict() + + experiment_config_path = os.path.join(file_path, 'experiment_config.json') + with open(experiment_config_path, 'w') as file: + json.dump(self.experiment_config, file, default=to_serializable, indent=4) + + results['infoConfig'] = {'experiment_config': self.experiment_config} + + if save_med3paResults: + self.__generate_Med3paResults_from_dict(results, file_path=file_path) + + def save_models(self, file_path: str, mode: str = 'all', id: str = None) -> None: + """ + Saves the experiment ipc and apc models as .pkl files, alongside the tree structure for the test set. + Args: + file_path (str): The file path to save the pickled files. + mode (str): Defines the type of models to save, either 'ipc', 'apc', or 'all'. Default is 'all'. + id (str): Optional identifier to append to the filenames. + """ + # Ensure the main directory exists + os.makedirs(file_path, exist_ok=True) + + # Function to generate the file name with optional id + def generate_file_name(base_name, id): + return f"{id}_{base_name}" if id else base_name + + if mode == 'all': + if self.ipc_model: + ipc_model_name = generate_file_name('ipc_model.pkl', id) + ipc_path = os.path.join(file_path, ipc_model_name) + self.ipc_model.save_model(ipc_path) + if self.apc_model: + apc_model_name = generate_file_name('apc_model.pkl', id) + apc_path = os.path.join(file_path, apc_model_name) + self.apc_model.save_model(apc_path) + if self.test_record.tree: + tree_structure_name = generate_file_name('tree.json', id) + tree_structure_path = os.path.join(file_path, tree_structure_name) + self.test_record.tree.save_tree(tree_structure_path) + elif mode == 'ipc': + if self.ipc_model: + ipc_model_name = generate_file_name('ipc_model.pkl', id) + ipc_path = os.path.join(file_path, ipc_model_name) + self.ipc_model.save_model(ipc_path) + elif mode == 'apc': + if self.apc_model: + apc_model_name = generate_file_name('apc_model.pkl', id) + apc_path = os.path.join(file_path, apc_model_name) + self.apc_model.save_model(apc_path) + if self.test_record.tree: + tree_structure_name = generate_file_name('tree.json', id) + tree_structure_path = os.path.join(file_path, tree_structure_name) + self.test_record.tree.save_tree(tree_structure_path) + + def __generate_Med3paResults_from_dict(self, data: dict, file_path: str) -> None: + """ + Generates a Med3paResult file from the provided data dictionary then saves it in file_path with current time. + Args: + data (dict): A dictionary containing all the relevant data. + file_path (str): Path where to save the Med3paResult file. + + """ + file_name = f"/MED3paResults_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}".replace( + r'[^a-zA-Z0-9-_]', "") + file_content = {"loadedFiles": {}, "isDetectron": False} + + # Process data based on tabs + tabs = ["infoConfig", "reference", "test"] + for tab in tabs: + if tab in data: + file_content["loadedFiles"][tab] = data[tab] + else: + print(f"Tab {tab} not found in") + + self.__save_dict_to_file(file_content, file_path + file_name + '.MED3paResults') + + @staticmethod + def __to_string(value: Any) -> str: + """ + Converts Any object to a string. + + Args: + value (Any): Value to be converted to string. + + Returns: + str: String representation of the value. + """ + if value is None: + return 'null' + return str(value) + + def __write_list(self, file: TextIO, l: list, indent: int, inc: int) -> None: + """ + Writes a list to the file. + + Args: + file (TextIO): The file to write the dictionary to. + l (list): The dictionary to be written in the file. + indent (int): The indent level applied to the file. + inc (int): The increments applied to the indent in the file. + """ + if len(l) == 0: + file.write("[]") + else: + file.write('[\n') + for idx, item in enumerate(l): + coma_list = '' if idx == len(l) - 1 else ',' + if isinstance(item, list): + self.__write_list(file, item, indent + inc, inc) + elif isinstance(item, dict): + file.write(" " * (indent + inc)) + self.__write_dict(file, item, indent + inc, inc) + else: + file.write(' ' * (indent + inc) + '"' + self.__to_string(item) + '"') + file.write(coma_list + "\n") + file.write(' ' * indent + ']') + + def __write_dict(self, file: TextIO, d: dict, indent: int = 0, inc: int = 2) -> None: + """ + Writes a dictionary to the file. + + Args: + file (TextIO): The file to write the dictionary to. + d (dict): The dictionary to be written in the file. + indent (int): The indent level applied to the file. + inc (int): The increments applied to the indent in the file. + """ + if len(d) == 0: + file.write("{}") + else: + file.write("{\n") + for index, (key, value) in enumerate(d.items()): + file.write(f'{" " * (indent + inc)}"{key}": ') + coma = '' if index == len(d) - 1 else ',' + if isinstance(value, dict): + self.__write_dict(file, value, (indent + inc), inc) + file.write(coma + '\n') + + elif isinstance(value, list): + self.__write_list(file, value, (indent + inc), inc) + file.write(coma + '\n') + + elif isinstance(value, bool): + file.write(f'{str(value).lower()}{coma}\n') + elif isinstance(value, str): + file.write(f'"{value}"{coma}\n') + else: + file.write(f'{self.__to_string(value)}{coma}\n') + file.write(" " * indent + "}") + + def __save_dict_to_file(self, dictionary: dict, file_path: str) -> None: + """ + Saves a dictionary to a file. + + Args: + dictionary (dict): The dictionary to be saved into the file. + file_path (str): Path to the file to save the dictionary to. + + Raises: + Exception: If an error occurs while saving the dictionary. + """ + try: + with open(file_path, 'w', encoding='utf-8') as file: + self.__write_dict(file, dictionary) + + print(f"Dictionary successfully saved to {file_path}") + except Exception as e: + print(f"Error saving dictionary to {file_path}: {e}") diff --git a/MED3pa/med3pa/tree.py b/MED3pa/med3pa/tree.py new file mode 100644 index 0000000..7465a76 --- /dev/null +++ b/MED3pa/med3pa/tree.py @@ -0,0 +1,302 @@ +""" +Manages the tree representation for the APC model. It includes the ``TreeRepresentation`` class which handles the construction and manipulation of decision trees +and ``TreeNode`` class that represents a node in the tree. +This module is crucial for profiling aggregated data and extracting valuable insights +""" + +import json +import numpy as np +from pandas import DataFrame, Series +from typing import Union, Any, Dict, List + +from MED3pa.models.concrete_regressors import DecisionTreeRegressorModel +from MED3pa.med3pa.profiles import Profile + + +def to_serializable(obj: Any, additional_arg: Any = None) -> Any: + """ + Convert an object to a JSON-serializable format. + Args: + obj (Any): The object to convert. + additional_arg (Any): Additional arguments to serialize + Returns: + Any: The JSON-serializable representation of the object. + """ + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, Profile): + if additional_arg is not None: + return obj.to_dict(additional_arg) + else: + return obj.to_dict() + if isinstance(obj, _TreeNode): + return obj.to_dict() + if isinstance(obj, dict): + return {k: to_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [to_serializable(v) for v in obj] + return obj + + +class TreeRepresentation: + """ + Represents the structure of a decision tree for a given set of features. + """ + + def __init__(self, features: List) -> None: + """ + Initializes the TreeRepresentation with a list of feature names. + + Args: + features (List[str]): List of feature names used in the decision tree. + """ + self.features = features + self.head = None + self.nb_nodes = 0 + + def build_tree(self, dtr: DecisionTreeRegressorModel, X: DataFrame, y: np.ndarray | Series, node_id: int = 0, + path: List = None) -> '_TreeNode': + """ + Recursively builds the tree representation starting from the specified node. + + Args: + dtr (DecisionTreeRegressorModel): Trained decision tree regressor model. + X (DataFrame): Training data observations. + y (Series): Training data labels. + node_id (int): Node ID to start building from. Defaults to 0. + path (Optional[List[str]]): Path to the current node. Defaults to ['*']. + + Returns: + _TreeNode: The root node of the tree representation. + """ + if path is None: + path = ['*'] + + self.nb_nodes += 1 + left_child = dtr.model.tree_.children_left[node_id] + right_child = dtr.model.tree_.children_right[node_id] + + node_value = y.mean() + node_max = y.max() + node_samples_ratio = dtr.model.tree_.n_node_samples[node_id] / dtr.model.tree_.n_node_samples[0] * 100 + + # If we are at a leaf + if left_child == -1: + curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio, + node_id=self.nb_nodes, path=path) + return curr_node + + node_thresh = dtr.model.tree_.threshold[node_id] + node_feature_id = dtr.model.tree_.feature[node_id] + node_feature = self.features[node_feature_id] + + curr_path = list(path) # Copy the current path to avoid modifying the original list + curr_node = _TreeNode(value=node_value, value_max=node_max, samples_ratio=node_samples_ratio, + threshold=node_thresh, feature=node_feature, feature_id=node_feature_id, + node_id=self.nb_nodes, path=curr_path) + + # Update paths for child nodes + left_path = curr_path + [f"{node_feature} <= {node_thresh}"] + right_path = curr_path + [f"{node_feature} > {node_thresh}"] + + curr_node.c_left = self.build_tree(dtr, X=X.loc[X[node_feature] <= node_thresh], + y=y[X[node_feature] <= node_thresh], + node_id=left_child, path=left_path) + curr_node.c_right = self.build_tree(dtr, X=X.loc[X[node_feature] > node_thresh], + y=y[X[node_feature] > node_thresh], + node_id=right_child, path=right_path) + + return curr_node + + def get_all_profiles(self, min_ca: float = 0, min_samples_ratio: float = 0) -> List: + """ + Retrieves all profiles from the tree that meet the minimum criteria for value and sample ratio. + + Args: + min_ca (float): Minimum value threshold for profiles. Defaults to 0. + min_samples_ratio (float): Minimum sample ratio threshold for profiles. Defaults to 0. + + Returns: + List[Profile]: A list of Profile instances meeting the specified criteria. + """ + if self.head is None: + raise ValueError("Tree has not been built yet.") + profiles = self.head.get_profile(min_samples_ratio=min_samples_ratio, min_ca=min_ca) + return profiles + + def get_all_nodes(self) -> List[Dict]: + """ + Retrieves all nodes from the tree with their paths. + + Returns: + List[dict]: A list of dictionaries representing nodes with their paths. + + Raises: + ValueError: If the tree has not been built yet. + """ + if self.head is None: + raise ValueError("Tree has not been built yet.") + return self.head.get_all_nodes() + + def save_tree(self, file_path: str) -> None: + """ + Saves the tree structure to a JSON file. + + Args: + file_path (str): The file path where the tree structure will be saved. + """ + if self.head is None: + raise ValueError("Tree has not been built yet.") + + tree_dict = self.to_dict() + with open(file_path, 'w') as file: + json.dump(tree_dict, file, default=to_serializable, indent=4) + + def to_dict(self) -> Dict: + """ + Converts the tree structure to a dictionary. + + Returns: + Dict: A dictionary representing the tree structure. + """ + if self.head is None: + raise ValueError("Tree has not been built yet.") + + tree_dict = self.head.to_dict() + tree_dict['features'] = self.features + return tree_dict + + +class _TreeNode: + """ + Represents a node in the tree structure. + """ + + def __init__(self, value: float = None, value_max: float = None, samples_ratio: float = None, + threshold: float = None, feature: str = None, feature_id: int = None, node_id: int = 0, + path: List = None) -> None: + """ + Initializes a _TreeNode object. + + Args: + value (float): The average value at the node. + value_max (float): The maximum value at the node. + samples_ratio (float): The percentage of total samples present at the node. + threshold (Optional[float]): The threshold used for splitting at this node. Defaults to None. + feature (Optional[str]): The feature used for splitting at this node. Defaults to None. + feature_id (Optional[int]): The identifier of the feature used for splitting. Defaults to None. + node_id (int): Unique identifier for the node. Defaults to 0. + path (Optional[List[str]]): The path from the root to this node. Defaults to an empty list. + """ + self.c_left = None + self.c_right = None + self.value = value + self.value_max = value_max + self.samples_ratio = samples_ratio + self.threshold = threshold + self.feature = feature + self.feature_id = feature_id + self.node_id = node_id + self.path = path if path is not None else [] + + def assign_node(self, X: Union[DataFrame, Series]) -> float: + """ + Assigns a value to a node based on input observations, navigating the tree until a leaf node is reached. + + Args: + X (Union[DataFrame, Series]): Input observations used to navigate and determine the value at a node. + + Returns: + float: The value assigned based on the input observations and the structure of the tree. + + Raises: + TypeError: If the input X is neither a pandas DataFrame nor a pandas Series. + """ + # Check if the current node is a leaf node + if self.c_left is None and self.c_right is None: + return self.value + + if isinstance(X, DataFrame): + X_value = X[self.feature].values[0] + elif isinstance(X, Series): + X_value = X[self.feature] + else: + raise TypeError( + f"Parameter X is of type {type(X)}, but it must be of type 'pandas.DataFrame' or 'pandas.Series'.") + + if X_value <= self.threshold: # If node split condition is true, then left children + c_node = self.c_left + else: + c_node = self.c_right + + return c_node.assign_node(X) + + def get_profile(self, min_samples_ratio: float, min_ca: float) -> List: + """ + Retrieves profiles from the subtree rooted at this node that meet the specified criteria. + + Args: + min_samples_ratio (float): The minimum sample ratio a node must have to be included in the output profiles. + min_ca (float): The minimum value a node must have to be included in the output profiles. + + Returns: + List[Profile]: A list of Profile instances representing nodes that meet the criteria. + + """ + profiles = [] + if self.c_left is not None and self.c_left.samples_ratio >= min_samples_ratio: + # Recursively retrieve profiles from the left child + profiles.extend(self.c_left.get_profile(min_samples_ratio, min_ca)) + + if self.c_right is not None and self.c_right.samples_ratio >= min_samples_ratio: + # Recursively retrieve profiles from the right child + profiles.extend(self.c_right.get_profile(min_samples_ratio, min_ca)) + + # Check if the current node meets the criteria + if self.samples_ratio >= min_samples_ratio and self.value >= min_ca: + profile = Profile(node_id=self.node_id, path=self.path) + profiles.append(profile) + + return profiles + + def get_all_nodes(self) -> List: + """ + Retrieves all nodes in the subtree rooted at this node with their paths. + + Returns: + List[dict]: A list of dictionaries representing nodes with their paths. + """ + nodes = [{ + 'node_id': self.node_id, + 'path': self.path + }] + + if self.c_left is not None: + nodes.extend(self.c_left.get_all_nodes()) + + if self.c_right is not None: + nodes.extend(self.c_right.get_all_nodes()) + + return nodes + + def to_dict(self) -> Dict: + """ + Converts the node and its children to a dictionary. + + Returns: + Dict: A dictionary representation of the node and its children. + """ + node_dict = { + 'threshold': self.threshold, + 'feature': self.feature, + 'feature_id': self.feature_id, + 'node_id': self.node_id, + 'path': self.path + } + if self.c_left is not None: + node_dict['c_left'] = self.c_left.to_dict() + if self.c_right is not None: + node_dict['c_right'] = self.c_right.to_dict() + return node_dict diff --git a/MED3pa/med3pa/uncertainty.py b/MED3pa/med3pa/uncertainty.py new file mode 100644 index 0000000..682c469 --- /dev/null +++ b/MED3pa/med3pa/uncertainty.py @@ -0,0 +1,119 @@ +"""This module handles the computation of uncertainty metrics. It defines an abstract base class +``UncertaintyMetric`` and concrete implementations such as ``AbsoluteError`` for calculating uncertainty based on the +difference between predicted probabilities and actual outcomes. An ``UncertaintyCalculator`` class is provided, +which allows users to specify which uncertainty metric to use, thereby facilitating the use of customized uncertainty +metrics for different analytical needs.""" + +import numpy as np +from abc import ABC, abstractmethod + + +class UncertaintyMetric(ABC): + """ + Abstract base class for uncertainty metrics. Defines the structure that all uncertainty metrics should follow. + """ + @staticmethod + @abstractmethod + def calculate(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Calculates the uncertainty metric based on input observations, predicted probabilities, and true labels. + + Args: + x (np.ndarray): Input observations. + predicted_prob (np.ndarray): Predicted probabilities by the model. + y_true (np.ndarray): True labels. + + Returns: + np.ndarray: An array of uncertainty values for each prediction. + """ + pass + + +class AbsoluteError(UncertaintyMetric): + """ + Concrete implementation of the UncertaintyMetric class using absolute error. + """ + @staticmethod + def calculate(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Calculates the absolute error between predicted probabilities and true labels, providing a measure of + prediction accuracy. + + Args: + x (np.ndarray): Input features (not used in this metric but included for interface consistency). + predicted_prob (np.ndarray): Predicted probabilities. + y_true (np.ndarray): True labels. + + Returns: + np.ndarray: Absolute errors between predicted probabilities and true labels. + """ + return 1 - np.abs(y_true - predicted_prob) + + +class SigmoidalError(UncertaintyMetric): + """ + Concrete implementation of the UncertaintyMetric class using Sigmoidal error. + """ + @staticmethod + def calculate(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray, threshold=0.5) -> np.ndarray: + """ + Calculates the Sigmoidal error between predicted probabilities and true labels, providing a measure of + prediction accuracy. + + Args: + x (np.ndarray): Input features (not used in this metric but included for interface consistency). + predicted_prob (np.ndarray): Predicted probabilities. + y_true (np.ndarray): True labels. + threshold (float): Classification threshold + + Returns: + np.ndarray: Sigmoidal errors between predicted probabilities and true labels. + """ + return 1 / ( + 1 + np.exp(10 * np.log(3) * (np.abs(y_true - predicted_prob) - np.abs(threshold - y_true)))) + + +class UncertaintyCalculator: + """ + Class for calculating uncertainty using a specified uncertainty metric. + """ + metric_mapping = { + 'absolute_error': AbsoluteError, + 'sigmoidal_error': SigmoidalError, + } + + def __init__(self, metric_name: str) -> None: + """ + Initializes the UncertaintyCalculator with a specific uncertainty metric. + + Args: + metric_name (str): The name of the uncertainty metric to use for calculations. + """ + if metric_name not in self.metric_mapping: + raise ValueError(f"Unrecognized metric name: {metric_name}. Available metrics: {list(self.metric_mapping)}") + + self.metric = self.metric_mapping[metric_name] + + def calculate_uncertainty(self, x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Calculates uncertainty for a set of predictions using the configured uncertainty metric. + + Args: + x (np.ndarray): Input features. + predicted_prob (np.ndarray): Predicted probabilities. + y_true (np.ndarray): True labels. + + Returns: + np.ndarray: Uncertainty values for each prediction, computed using the specified metric. + """ + return self.metric.calculate(x, predicted_prob, y_true) + + @classmethod + def supported_metrics(cls) -> list: + """ + Returns a list of supported uncertainty metrics. + + Returns: + list: A list of strings representing the names of the supported uncertainty metrics. + """ + return list(cls.metric_mapping) diff --git a/MED3pa/models/__init__.py b/MED3pa/models/__init__.py new file mode 100644 index 0000000..7958a52 --- /dev/null +++ b/MED3pa/models/__init__.py @@ -0,0 +1,7 @@ +from .abstract_models import * +from .base import * +from .concrete_classifiers import * +from .concrete_regressors import * +from .classification_metrics import * +from .regression_metrics import * +from .factories import * diff --git a/MED3pa/models/abstract_metrics.py b/MED3pa/models/abstract_metrics.py new file mode 100644 index 0000000..790f026 --- /dev/null +++ b/MED3pa/models/abstract_metrics.py @@ -0,0 +1,38 @@ +""" +The ``abstract_metrics.py`` module defines the ``EvaluationMetric`` abstract base class, +providing a standard interface for calculating metric values for model evaluations. +""" + +from abc import ABC, abstractmethod + + +class EvaluationMetric(ABC): + """ + Abstract base class for all evaluation metrics. This class provides a standardized interface for calculating + metric values across different types of tasks, ensuring consistency and reusability. + """ + + @classmethod + @abstractmethod + def get_metric(cls, metric_name: str): + """ + Get the metric function based on the metric name. + + Args: + metric_name (str): The name of the metric. + + Returns: + function: The function corresponding to the metric. + """ + pass + + @classmethod + @abstractmethod + def supported_metrics(cls) -> list: + """ + Get a list of supported regression metrics. + + Returns: + list: A list of supported regression metrics. + """ + pass diff --git a/MED3pa/models/abstract_models.py b/MED3pa/models/abstract_models.py new file mode 100644 index 0000000..6513b6f --- /dev/null +++ b/MED3pa/models/abstract_models.py @@ -0,0 +1,692 @@ +"""The abstract_models.py module defines core abstract classes that serve as the foundation for model management in +the system. It includes ``Model``, which standardizes basic operations like evaluation and parameter validation +across all models. It also introduces specialized abstract classes such as ``ClassificationModel`` and +``RegressionModel``, each adapting these operations to specific needs of classification and regression tasks.""" + +import json +import matplotlib.pyplot as plt +import numpy as np +import os +import pandas as pd +import pickle +from abc import ABC, abstractmethod +from copy import deepcopy +from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.calibration import CalibratedClassifierCV, CalibrationDisplay +from sklearn.metrics import roc_curve, precision_recall_curve, RocCurveDisplay +from typing import Any, Dict, List, Optional, Self, Union + +from MED3pa.models.data_strategies import DataPreparingStrategy +from MED3pa.models.classification_metrics import ClassificationEvaluationMetrics +from MED3pa.models.regression_metrics import RegressionEvaluationMetrics + + +class Model(ABC, BaseEstimator): + """ + An abstract base class for all models, defining a common API for model operations such as evaluation and parameter validation. + + Attributes: + model (Any): The underlying model instance. + model_class (type): The class type of the underlying model instance. + params (dict): The params used for initializing the model. + data_preparation_strategy (DataPreparingStrategy): Strategy for preparing data before training or evaluation. + pickled_model (Boolean): A boolean indicating whether the model has been loaded from a pickled file. + """ + + def __init__(self, random_state: int = None, verbose: bool = False) -> None: + """ + Initializes the Model instance. + + Args: + random_state (int): The random state used to train the model. + verbose (Boolean): Control of output messages. + + """ + super().__init__() + self.model = None + self.model_class = None + self.params = None + self.data_preparation_strategy = None + self.pickled_model = False + self.file_path = None + self._random_state = random_state + self.verbose = verbose + + def get_model(self) -> Any: + """ + Retrieves the underlying model instance, which is typically a machine learning model object. + + Returns: + Any: The underlying model instance if set, None otherwise. + """ + return self.model + + def get_path(self) -> str: + """ + Retrieves the file path of the model if it has been loaded from a pickled file. + + Returns: + str: The file path of the model if it has been loaded from a pickled file, None otherwise. + """ + return self.file_path + + def get_model_type(self) -> Optional[str]: + """ + Retrieves the class type of the underlying model instance, which indicates the specific + implementation of the model used. + + Returns: + Optional[str]: The class of the model if set, None otherwise. + """ + return self.model_class.__name__ if self.model_class else None + + def get_data_strategy(self) -> Optional[str]: + """ + Retrieves the data preparation strategy associated with the model. This strategy handles + how data should be formatted before being passed to the model for training or evaluation. + + Returns: + Optional[str]: The name of the current data preparation strategy if set, None otherwise. + """ + return self.data_preparation_strategy.__class__.__name__ if self.data_preparation_strategy else None + + def get_params(self, deep: bool = False) -> Dict[str, Any]: + """ + Retrieves the underlying model's parameters. + + Returns: + Dict[str, Any]: the model's parameters. + """ + return self.params + + def is_pickled(self) -> bool: + """ + Returns whether the model has been loaded from a pickled file. + + Returns: + Boolean: has the model been loaded from a pickled file. + """ + return self.pickled_model + + def set_model(self, model: Any) -> None: + """ + Sets the underlying model instance and updates the model class to match the type of the given model. + + Args: + model (Any): The model instance to be set. + """ + self.model = model + self.model_class = type(model) + + def set_params(self, params: dict = None, **kwargs) -> Self: + """ + Sets the parameters for the model. These parameters are typically used for model initialization or configuration. + + Args: + params (Dict[str, Any]): A dictionary of parameters for the model. + + Returns: + Self: The model instance with the updated parameters. + """ + if params is None: + params = {} + if kwargs is not None: + params.update(kwargs) + self.params = params + return self + + def set_file_path(self, file_path: str) -> None: + """ + Sets the file path of the model. + + Args: + file_path (str): the file path of the model. + """ + self.file_path = file_path + + def update_params(self, params: dict) -> None: + """ + Updates the current model parameters by merging new parameter values from the given dictionary. + This method allows for dynamic adjustment of model configuration during runtime. + + Args: + params (Dict[str, Any]): A dictionary containing parameter names and values to be updated. + """ + self.params.update(params) + + def set_data_strategy(self, strategy: DataPreparingStrategy) -> None: + """ + Sets the underlying model's data preparation strategy. + + Args: + strategy (DataPreparingStrategy): strategy to be used to prepare the data for training, validation...etc. + """ + self.data_preparation_strategy = strategy + + def get_info(self) -> Dict[str, Any]: + """ + Retrieves detailed information about the model. + + Returns: + Dict[str, Any]: A dictionary containing information about the model's type, parameters, + data preparation strategy, and whether it's a pickled model. + """ + return { + "model": self.__class__.__name__, + "model_type": self.get_model_type(), + "params": self.get_params(), + "data_preparation_strategy": self.get_data_strategy() if self.get_data_strategy() else None, + "pickled_model": self.is_pickled(), + "file_path": self.get_path() + } + + def save(self, path: str) -> None: + """ + Saves the model instance as a pickled file and the parameters as a JSON file within the specified directory. + + Args: + path (str): The directory path where the model and parameters will be saved. + """ + # Create the directory if it does not exist + if not os.path.exists(path): + os.makedirs(path) + + # Define file paths + model_path = os.path.join(path, 'model_instance.pkl') + params_path = os.path.join(path, 'model_info.json') + + # Save the model as a pickled file + with open(model_path, 'wb') as model_file: + pickle.dump(self.model, model_file) + + # Save the parameters as a JSON file + with open(params_path, 'w') as params_file: + json.dump(self.get_info(), params_file) + + @abstractmethod + def evaluate(self, X: np.ndarray, y: np.ndarray, eval_metrics: List[str], print_results: bool = False + ) -> Dict[str, float]: + """ + Evaluates the model using specified metrics. + + Args: + X (np.ndarray): observations for evaluation. + y (np.ndarray): True labels for evaluation. + eval_metrics (List[str]): Metrics to use for evaluation. + print_results (bool, optional): Whether to print the evaluation results. Defaults to False. + + Returns: + Dict[str, float]: A dictionary with metric names and their evaluated scores. + """ + pass + + @staticmethod + def print_evaluation_results(results: Dict[str, float]) -> None: + """ + Prints the evaluation results in a formatted manner. + + Args: + results (Dict[str, float]): A dictionary with metric names and their evaluated scores. + """ + print("Evaluation Results:") + for metric, value in results.items(): + print(f"{metric}: {value:.2f}") + + @staticmethod + def validate_params(params: Dict[str, Any], valid_param_sets: List[set]) -> Dict[str, Any]: + """ + Validates the model parameters against a list of valid parameter sets. + + Args: + params (Dict[str, Any]): Parameters to validate. + valid_param_sets (List[set]): A list of sets containing valid parameter names. + + Returns: + Dict[str, Any]: Validated parameters. + + Raises: + ValueError: If any invalid parameters are found. + """ + combined_valid_params = set().union(*valid_param_sets) + invalid_params = [k for k in params.keys() if k not in combined_valid_params] + if invalid_params: + raise ValueError(f"Invalid parameters found: {invalid_params}") + return {k: v for k, v in params.items() if k in combined_valid_params} + + +class ClassificationModel(Model, ClassifierMixin): + """ + Abstract base class for classification models, extending the generic Model class with additional + classification-specific methods. + """ + + def __init__(self, objective: str = 'binary:logistic', class_weighting: bool = False, random_state: int = None, + verbose: bool = False): + """ + Initializes the ClassificationModel instance. + + Args: + objective (str): The objective of the model, used by certain model packages. + class_weighting (bool): Whether to add class weighting correction during model training. Default: False + random_state (Optional int): The random state used to train the model. + verbose (Boolean): Control of output messages. + + """ + super().__init__(random_state=random_state, verbose=verbose) + self._objective = objective + self._class_weighting = class_weighting + self._threshold = 0.5 + self._calibration = None + self.classes_ = np.array([0, 1]) # To allow model calibration. The CalibratedClassifierCV uses this variable to + # ensure that the model to be calibrated has been fitted. + + @property + def threshold(self) -> float: + """ + Returns the threshold for the classification. + + Returns: + float: The threshold for the classification. + """ + return self._threshold + + @staticmethod + def balance_train_weights(y_train: np.ndarray) -> np.ndarray: + """ + Balances the training weights based on the class distribution in the training data. + + Args: + y_train (np.ndarray): Labels for training. + + Returns: + np.ndarray: Balanced training weights. + + Raises: + AssertionError: If balancing is attempted on non-binary classification data. + """ + _, counts = np.unique(y_train, return_counts=True) + assert len(counts) == 2, 'Only binary classification is supported' + c_neg, c_pos = counts[0], counts[1] + pos_weight, neg_weight = 2 * c_neg / (c_neg + c_pos), 2 * c_pos / (c_neg + c_pos) + train_weights = np.array([pos_weight if label == 1 else neg_weight for label in y_train]) + return train_weights + + def calibrate_model(self, *, y_true: Union[pd.Series, np.ndarray], data: pd.DataFrame, method: str = 'sklearn' + ) -> None: + """ + Calibrates the model based on the provided data. + + Args: + y_true (Union[pd.Series, np.ndarray]): Labels of data used to calibrate the model. + data (pd.DataFrame): Data used to calibrate the model. + method (str): Method used to calibrate the model. Possible values are ['sklearn']. Default: 'sklearn' + """ + if method == 'sklearn': + calibration = CalibratedClassifierCV(estimator=deepcopy(self), method='sigmoid', cv='prefit') + calibration.fit(data, y_true) + else: + raise NotImplementedError + self._calibration = calibration + + @abstractmethod + def fit(self, x_train: Union[np.ndarray, pd.DataFrame], y_train: Union[np.ndarray, pd.Series], + training_parameters: Optional[Dict[str, Any]], balance_train_classes: bool, weights: np.ndarray = None, + *params) -> None: + """ + Fits the classification model using provided training and validation data. + This method is functionally equivalent to train. Subclasses may override either fit or train, + and the other will default to the same behavior. + + Args: + x_train (Union[np.ndarray, pd.DataFrame]): observations for training. + y_train (Union[np.ndarray, pd.Series]): Labels for training. + training_parameters (Dict[str, Any], optional): Additional training parameters. + balance_train_classes (bool): Whether to balance the training classes. + weights (Optional[np.ndarray], optional): Weights for the training data. + *params : Additional training parameters for specific models. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + pass + + @abstractmethod + def train(self, x_train: Union[np.ndarray, pd.DataFrame], y_train: Union[np.ndarray, pd.Series], + training_parameters: Optional[Dict[str, Any]], balance_train_classes: bool, weights: np.ndarray = None, + *params) -> None: + """ + Trains the classification model using provided training and validation data. + This method is functionally equivalent to fit. Subclasses may override either fit or train, + and the other will default to the same behavior. + + Args: + x_train (Union[np.ndarray, pd.DataFrame]): observations for training. + y_train (Union[np.ndarray, pd.Series]): Labels for training. + training_parameters (Dict[str, Any], optional): Additional training parameters. + balance_train_classes (bool): Whether to balance the training classes. + weights (Optional[np.ndarray], optional): Weights for the training data. + *params : Additional training parameters for specific models. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + pass + + def __init_subclass__(cls, **kwargs): + """ + Initializes the fit and train methods so only one needs to be defined in child classes + + Raises: + NotImplementedError: At least one of the 'fit' or 'train' methods must be defined. + """ + super().__init_subclass__(**kwargs) + # Check if either fit or train has been overridden in the subclass + if 'fit' not in cls.__dict__ and 'train' not in cls.__dict__: + raise NotImplementedError(f"Class {cls.__name__} must implement either `fit` or `train`.") + + if 'fit' not in cls.__dict__: + cls.fit = cls.train + if 'train' not in cls.__dict__: + cls.train = cls.fit + + def train_to_disagree(self, x_train: np.ndarray, y_train: np.ndarray, + x_test: np.ndarray, y_test: np.ndarray, + training_parameters: Optional[Dict[str, Any]], balance_train_classes: bool) -> None: + """ + Trains the classification model using provided training and validation data. Specific to Detectron experiments. + + Args: + x_train (np.ndarray): observations for training. + y_train (np.ndarray): Labels for training. + x_test (np.ndarray): observations for testing. + y_test (np.ndarray): Labels for testing. + training_parameters (Dict[str, Any], optional): Additional training parameters. + balance_train_classes (bool): Whether to balance the training classes. + + """ + N = len(y_test) + + # if additional training parameters are provided and both classes are present + if training_parameters and len(np.unique(y_train)) == 2: + training_weights = self.balance_train_weights(y_train) if balance_train_classes else ( + training_parameters.get('training_weights', np.ones_like(y_train))) + else: + training_weights = np.ones_like(y_train) + + # prepare the data for training + data = np.concatenate([x_train, x_test]) + label = np.concatenate([y_train, 1 - y_test]) + weights = np.concatenate([training_weights, 1 / (N + 1) * np.ones(N)]) + + self.train(data, label, training_parameters=training_parameters, + balance_train_classes=balance_train_classes, weights=weights) + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Makes label predictions for the given input observations. + + Args: + X (np.ndarray): observations for prediction. + + Returns: + np.ndarray: The predicted labels or probabilities. + + """ + return (self.predict_proba(X)[:, 1] >= self._threshold).astype(int) + + @abstractmethod + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """ + Makes probability predictions for the given input observations. + + Args: + X (np.ndarray): observations for prediction. + + Returns: + np.ndarray: The predicted labels or probabilities. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + pass + + def evaluate(self, X: np.ndarray, y: np.ndarray, eval_metrics: Union[str, List[str]] = None, + print_results: bool = False) -> Dict[str, float]: + """ + Evaluates the model using specified metrics. + + Args: + X (np.ndarray): Features for evaluation. + y (np.ndarray): True labels for evaluation. + eval_metrics (List[str]): Metrics to use for evaluation. + print_results (bool, optional): Whether to print the evaluation results. + + Returns: + Dict[str, float]: A dictionary with metric names and their evaluated scores. + + Raises: + ValueError: If the model has not been trained before evaluation. + """ + if self.model is None: + raise ValueError("Model must be trained before evaluation.") + + if eval_metrics is None: + eval_metrics = ClassificationEvaluationMetrics.supported_metrics() + + # Ensure eval_metrics is a list + if isinstance(eval_metrics, str): + eval_metrics = [eval_metrics] + + probs = self.predict_proba(X)[:, 1] + preds = self.predict(X) + evaluation_results = {} + for metric_name in eval_metrics: + metric_function = ClassificationEvaluationMetrics.get_metric(metric_name) + if metric_function: + evaluation_results[metric_name] = metric_function(y_true=y, y_prob=probs, y_pred=preds) + else: + print(f"Error: The metric '{metric_name}' is not supported.") + + if print_results: + self.print_evaluation_results(results=evaluation_results) + return evaluation_results + + def plot_probability_distribution(self, X: np.ndarray, y: np.ndarray, save_path: str = None) -> None: + """ + Plots the predicted probability distributions for each class in a binary classification model. + + Args: + X (np.ndarray): Features for evaluation. + y (np.ndarray): True labels for evaluation. + save_path (str): Optional path to save the plot. Defaults to None, which displays the plot in terminal. + """ + plt.clf() + probas = self.predict_proba(X)[:, 1] + + # Separate probabilities for each class + class_0_probas = probas[y == 0] + class_1_probas = probas[y == 1] + + # Plot the probability distributions + plt.hist(class_1_probas, bins=20, alpha=0.5, label='Class 1', color='blue') + plt.hist(class_0_probas, bins=20, alpha=0.5, label='Class 0', color='red') + plt.xlabel('Predicted Probability') + plt.ylabel('Frequency (%)') + plt.title('Real Class Probability Distribution for Each Class') + plt.legend(loc='upper center') + if save_path: + plt.savefig(save_path) + else: + plt.show() + + def plot_roc_curve(self, X: np.ndarray, target: np.ndarray, save_path: str = None) -> None: + """ + Plots the AUROC curve of the model. + + Args: + X (np.ndarray): Features for evaluation. + target (np.ndarray): True labels for evaluation. + save_path (str): Optional path to save the plot. Defaults to None, which displays the plot in terminal. + """ + plt.clf() + predictions = self.predict_proba(X)[:, 1] + fpr, tpr, threshold = roc_curve(target, predictions) + roc_auc = roc_auc_score(target, predictions) + display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, + estimator_name=type(self).__name__) + display.plot() + + if save_path: + plt.savefig(save_path) + else: + plt.show() + + def show_calibration(self, X: np.ndarray, target: np.ndarray, save_path: str = None) -> None: + """ + Plots the calibration curve of the model. + + Args: + X (np.ndarray): Features for evaluation. + target (np.ndarray): True labels for evaluation. + save_path (str): Optional path to save the plot. Defaults to None, which displays the plot in terminal. + """ + plt.clf() + predicted_prob = self.predict_proba(X)[:, 1] + CalibrationDisplay.from_predictions(target, predicted_prob) + if save_path: + plt.savefig(save_path) + else: + plt.show() + + @staticmethod + def _optimal_threshold_auc(target: np.ndarray, predicted: np.ndarray) -> float: + """ + Calculates the optimal binary classification threshold using the AUROC method. + + Args: + target (np.ndarray): True labels for evaluation. + predicted (np.ndarray): Labels predicted by the model. + + Returns: + float: The optimal binary classification threshold using the AUROC method. + """ + fpr, tpr, threshold = roc_curve(target, predicted) + i = np.arange(len(tpr)) + roc = pd.DataFrame({'tf': pd.Series(tpr - (1 - fpr), index=i), 'threshold': pd.Series(threshold, index=i)}) + roc_t = roc.iloc[roc.tf.abs().argsort()[:1]] + + return list(roc_t['threshold'])[0] + + @staticmethod + def _optimal_threshold_auprc(target: np.ndarray, predicted: np.ndarray) -> float: + """ + Calculates the optimal binary classification threshold using the AUPRC method. + + Args: + target (np.ndarray): True labels for evaluation. + predicted (np.ndarray): Labels predicted by the model. + + Returns: + float: The optimal binary classification threshold using the AUPRC method. + """ + precision, recall, threshold = precision_recall_curve(target, predicted) + # Remove last element + precision = precision[:-1] + recall = recall[:-1] + + i = np.arange(len(recall)) + roc = pd.DataFrame({'tf': pd.Series(precision * recall, index=i), 'threshold': pd.Series(threshold, index=i)}) + roc_t = roc.iloc[roc.tf.abs().argsort()[:1]] + + return list(roc_t['threshold'])[0] + + +class RegressionModel(Model): + """ + Abstract base class for regression models, providing a framework for training and prediction in regression tasks. + """ + + @abstractmethod + def train(self, x_train: np.ndarray, y_train: np.ndarray, + training_parameters: Optional[Dict[str, Any]]) -> None: + """ + Trains the regression model using provided training and validation data. + + Args: + x_train (np.ndarray): observations for training. + y_train (np.ndarray): Labels for training. + training_parameters (Dict[str, Any], optional): Additional training parameters. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + pass + + @abstractmethod + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Makes predictions for the given input observations. + + Args: + X (np.ndarray): observations for prediction. + + Returns: + np.ndarray: The predicted values. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + pass + + def evaluate(self, X: np.ndarray, y: np.ndarray, eval_metrics: List[str], print_results: bool = False + ) -> Dict[str, float]: + """ + Evaluates the model using specified metrics. + + Args: + X (np.ndarray): observations for evaluation. + y (np.ndarray): True labels for evaluation. + eval_metrics (List[str]): Metrics to use for evaluation. + print_results (bool, optional): Whether to print the evaluation results. + + Returns: + Dict[str, float]: A dictionary with metric names and their evaluated scores. + + Raises: + ValueError: If the model has not been trained before evaluation. + """ + if self.model is None: + raise ValueError("Model must be trained before evaluation.") + + predictions = self.predict(X) + evaluation_results = {} + + for metric_name in eval_metrics: + metric_function = RegressionEvaluationMetrics.get_metric(metric_name) + if metric_function: + evaluation_results[metric_name] = metric_function(y, predictions) + else: + print(f"Error: The metric '{metric_name}' is not supported.") + + if print_results: + self.print_evaluation_results(results=evaluation_results) + + return evaluation_results + + def _ensure_numpy_arrays(self, observations: Any, labels: Optional[np.ndarray] = None) -> tuple: + """ + Ensures that the input data is converted to NumPy array format, using the defined data preparation strategy. + This method is used internally to standardize input data before training, predicting, or evaluating. + + Args: + observations (Any): observations data, which can be in various formats like lists, Pandas DataFrames, or already in NumPy arrays. + labels (np.ndarray, optional): Labels data, similar to observations in that it can be in various formats. If labels are not provided, + only observations are converted and returned. + + Returns: + tuple: The observations and labels (if provided) as NumPy arrays. If labels are not provided, labels in the tuple will be None. + """ + if not isinstance(observations, np.ndarray) or (labels is not None and not isinstance(labels, np.ndarray)): + return self.data_preparation_strategy.execute(observations, labels)[:2] + else: + return observations, labels diff --git a/MED3pa/models/base.py b/MED3pa/models/base.py new file mode 100644 index 0000000..c664dac --- /dev/null +++ b/MED3pa/models/base.py @@ -0,0 +1,133 @@ +"""This module introduces a singleton manager that manages the instantiation and cloning of a base model, +which is particularly useful for applications like ``med3pa`` where a consistent reference model is +necessary. It employs the **Singleton and Prototype** design patterns to ensure that the base model is instantiated +once and can be cloned without reinitialization.""" + +import pickle +from io import BytesIO +from typing import Any, Dict, Optional + +from .abstract_models import Model + + +class BaseModelManager: + """ + Singleton manager class for the base model. ensures the base model is set only once. + """ + __baseModel = None + _threshold = 0.5 + + def __init__(self, model: Optional[Model | Any] = None): + """ + Initializes the BaseModelManager instance. + + Args: + model (Optional[Model | Any]): The base model to be used. + """ + self.set_base_model(model) + + def set_base_model(self, model: Model | Any): + """ + Sets the base model for the manager, ensuring Singleton behavior. + + Parameters: + model (Model | Any): The model to be set as the base model. + + Raises: + TypeError: If the base model has already been initialized. + """ + if self.__baseModel is None: + self.__baseModel = model + else: + raise TypeError("The Base Model has already been initialized") + + def get_instance(self) -> Model: + """ + Returns the instance of the base model, ensuring Singleton access. + + Returns: + The base model instance. + + Raises: + TypeError: If the base model has not been initialized yet. + """ + if self.__baseModel is None: + raise TypeError("The Base Model has not been initialized yet") + return self.__baseModel + + def clone_base_model(self) -> Model: + """ + Creates and returns a deep clone of the base model, following the Prototype pattern. + + This method uses serialization and deserialization to clone complex model attributes, + allowing for independent modification of the cloned model. + + Returns: + A cloned instance of the base model. + + Raises: + TypeError: If the base model has not been initialized yet. + """ + if self.__baseModel is None: + raise TypeError("The Base Model has not been initialized and cannot be cloned") + else: + cloned_model = type(self.__baseModel)() + # Serialize and deserialize the entire base model to create a deep clone. + if hasattr(self.__baseModel, 'model') and self.__baseModel.model is not None: + buffer = BytesIO() + pickle.dump(self.__baseModel.model, buffer) + buffer.seek(0) + cloned_model.model = pickle.load(buffer) + cloned_model.model_class = self.__baseModel.model_class + cloned_model.pickled_model = True + cloned_model.params = self.__baseModel.params + else: + for attribute, value in vars(self.__baseModel).items(): + setattr(cloned_model, attribute, value) + + return cloned_model + + def reset(self) -> None: + """ + Resets the singleton instance, allowing for reinitialization. + + This method clears the current base model, enabling the set_base_model method + to set a new base model. + """ + self.__baseModel = None + + @property + def threshold(self): + if hasattr(BaseModelManager.__getattribute__(self, "_BaseModelManager__baseModel"), "threshold"): + return self.__baseModel.threshold + return self._threshold + + def get_info(self) -> Dict[str, Any]: + """ + Retrieves detailed information about the model. + + Returns: + Dict[str, Any]: A dictionary containing information about the model's type, parameters, + data preparation strategy, and whether it's a pickled model. + """ + if callable(getattr(self.__baseModel, "get_info", None)): + return self.__baseModel.get_info() + + return { + "model": self.__baseModel.__class__.__name__, + "model_type": self.__baseModel.__class__.__name__, + "params": self.__baseModel.get_params(), + "data_preparation_strategy": None, + "pickled_model": getattr(self.__baseModel, "pickled_model", False), + "file_path": getattr(self.__baseModel, "file_path", "") + } + + def __getattr__(self, name) -> Any: + """ + Returns attributes of the baseModel instance. + + Returns: + Any: The attributes of the baseModel instance. + """ + # Delegate attribute access to self.model + return getattr(self.__baseModel, name) diff --git a/MED3pa/models/classification_metrics.py b/MED3pa/models/classification_metrics.py new file mode 100644 index 0000000..61eeeeb --- /dev/null +++ b/MED3pa/models/classification_metrics.py @@ -0,0 +1,317 @@ +""" +The ``classification_metrics.py`` module defines the ``ClassificationEvaluationMetrics`` class, +that contains various classification metrics that can be used to assess the model's performance. +""" + +import numpy as np +import warnings +from collections.abc import Callable +from sklearn.metrics import (accuracy_score, average_precision_score, f1_score, log_loss, matthews_corrcoef, + precision_score, recall_score, roc_auc_score) +from typing import List, Optional + +from .abstract_metrics import EvaluationMetric + + +class ClassificationEvaluationMetrics(EvaluationMetric): + """ + A class to compute various classification evaluation metrics. + """ + + @staticmethod + def accuracy(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the accuracy score. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Accuracy score. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + return accuracy_score(y_true, y_pred, sample_weight=sample_weight) + + @staticmethod + def recall(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the recall score. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Recall score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return recall_score(y_true, y_pred, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def roc_auc(y_true: np.ndarray, y_prob: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the ROC AUC score. + + Args: + y_true (np.ndarray): True labels. + y_prob (np.ndarray): Predicted probabilities. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: ROC AUC score. + """ + if y_true.size == 0 or y_prob.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return roc_auc_score(y_true, y_prob, sample_weight=sample_weight) + + @staticmethod + def average_precision(y_true: np.ndarray, y_prob: np.ndarray, sample_weight: np.ndarray = None, **kwargs + ) -> Optional[float]: + """ + Calculate the average precision score. + + Args: + y_true (np.ndarray): True labels. + y_prob (np.ndarray): Predicted probabilities. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Average precision score. + """ + if y_true.size == 0 or y_prob.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return average_precision_score(y_true, y_prob, sample_weight=sample_weight) + + @staticmethod + def matthews_corrcoef(y_true: np.ndarray, y_pred: np.ndarray, **kwargs + ) -> Optional[float]: + """ + Calculate the Matthews correlation coefficient. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + + Returns: + float: Matthews correlation coefficient. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return matthews_corrcoef(y_true, y_pred) + + @staticmethod + def precision(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs + ) -> Optional[float]: + """ + Calculate the precision score. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Precision score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return precision_score(y_true, y_pred, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def f1_score(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the F1 score. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: F1 score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return f1_score(y_true, y_pred, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def sensitivity(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs + ) -> Optional[float]: + """ + Calculate the sensitivity (recall for the positive class). + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Sensitivity score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return recall_score(y_true, y_pred, pos_label=1, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def specificity(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs + ) -> Optional[float]: + """ + Calculate the specificity (recall for the negative class). + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Specificity score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return recall_score(y_true, y_pred, pos_label=0, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def ppv(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the positive predictive value (PPV). + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Positive predictive value. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return precision_score(y_true, y_pred, pos_label=1, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def npv(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the negative predictive value (NPV). + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Negative predictive value. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return precision_score(y_true, y_pred, pos_label=0, sample_weight=sample_weight, zero_division=0) + + @staticmethod + def balanced_accuracy(y_true: np.ndarray, y_pred: np.ndarray, **kwargs) -> Optional[float]: + """ + Calculate the balanced accuracy score. + + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels. + + Returns: + float: Balanced accuracy score. + """ + if y_true.size == 0 or y_pred.size == 0 or len(np.unique(y_true)) == 1: + return None + sens = ClassificationEvaluationMetrics.sensitivity(y_true, y_pred) + spec = ClassificationEvaluationMetrics.specificity(y_true, y_pred) + if sens is not None and spec is not None: + return (sens + spec) / 2 + else: + return None + + @staticmethod + def log_loss(y_true: np.ndarray, y_prob: np.ndarray, sample_weight: np.ndarray = None, **kwargs) -> Optional[float]: + """ + Calculate the log loss score. + + Args: + y_true (np.ndarray): True labels. + y_prob (np.ndarray): Predicted probabilities. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Log loss score. + """ + # return np.random.rand() + if y_true.size == 0 or y_prob.size == 0 or len(np.unique(y_true)) == 1: + return None + y_pred = np.clip(y_prob, 1e-15, 1 - 1e-15) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return log_loss(y_true, y_pred, sample_weight=sample_weight) + + @classmethod + def get_metric(cls, metric_name: str = '') -> Callable | List[str]: + """ + Get the metric function based on the metric name. + + Args: + metric_name (str): The name of the metric. + + Returns: function: The function corresponding to the metric, if metric_name is not ''. Otherwise returns the + list of supported classification metrics. + """ + metrics_mappings = { + 'Accuracy': cls.accuracy, + 'BalancedAccuracy': cls.balanced_accuracy, + 'Precision': cls.precision, + 'Recall': cls.recall, + 'F1Score': cls.f1_score, + 'Specificity': cls.specificity, + 'Sensitivity': cls.sensitivity, + 'Auc': cls.roc_auc, + 'LogLoss': cls.log_loss, + 'Auprc': cls.average_precision, + 'NPV': cls.npv, + 'PPV': cls.ppv, + 'MCC': cls.matthews_corrcoef + } + if metric_name == '': + return list(metrics_mappings) + else: + metric_function = metrics_mappings.get(metric_name) + if metric_function is None: + raise ValueError( + f"Metric '{metric_name}' is not recognized. Please choose from: {list(metrics_mappings)}") + return metric_function + + @classmethod + def supported_metrics(cls) -> List[str]: + """ + Get a list of supported classification metrics. + + Returns: + list: A list of supported classification metrics. + """ + return cls.get_metric() diff --git a/MED3pa/models/concrete_classifiers.py b/MED3pa/models/concrete_classifiers.py new file mode 100644 index 0000000..b4f781f --- /dev/null +++ b/MED3pa/models/concrete_classifiers.py @@ -0,0 +1,231 @@ +""" +This module offers concrete implementations of specific classification models, such as XGBoost. +It adapts the abstract interfaces defined in ``abstract_models.py`` to provide fully functional models ready for training and prediction. +""" + +import numpy as np +import xgboost as xgb +from typing import Any, Dict, Optional, Union + +from MED3pa.models.abstract_models import ClassificationModel +from MED3pa.models.data_strategies import ToDmatrixStrategy +from MED3pa.models.xgboost_params import valid_xgboost_custom_params, valid_xgboost_params + + +class XGBoostModel(ClassificationModel): + """ + A concrete implementation of the ClassificationModel class for XGBoost models. + This class provides functionalities to train, predict, and evaluate models built with the XGBoost library. + """ + def __init__(self, params: Optional[Dict[str, Any]] = None, + model: Optional[Union[xgb.Booster, xgb.XGBClassifier]] = None) -> None: + """ + Initializes the XGBoostModel either with parameters for a new model or a loaded pickled model. + + Args: + params (Optional[Dict[str, Any]]): A dictionary of parameters for the booster model. + model (Optional[Union[xgb.Booster, xgb.XGBClassifier]]): A loaded pickled model. + """ + super().__init__() + self.set_params(params) + # if the model is loaded from a pickled file + if model is not None: + self.set_model(model) + # if it's a new model, use xgb.Booster by default + else: + self.model_class = xgb.Booster + self.pickled_model = model is not None + self.set_data_strategy(ToDmatrixStrategy()) + + def _ensure_dmatrix(self, features: Any, labels: Optional[np.ndarray] = None, weights: Optional[np.ndarray] = None) -> xgb.DMatrix: + """ + Ensures that the input data is converted to a DMatrix format, using the defined data preparation strategy. + + Args: + features (Any): Features array. + labels (Optional[np.ndarray]): Labels array. + weights (Optional[np.ndarray]): Weights array. + + Returns: + xgb.DMatrix: A DMatrix object. + """ + if not isinstance(features, xgb.DMatrix): + return self.data_preparation_strategy.execute(features, labels, weights) + else: + return features + + def train(self, x_train: np.ndarray, y_train: np.ndarray, x_validation: np.ndarray, y_validation: np.ndarray, + training_parameters: Optional[Dict[str, Any]], balance_train_classes: bool) -> None: + """ + Trains the model on the provided dataset. + + Args: + x_train (np.ndarray): Features for training. + y_train (np.ndarray): Labels for training. + x_validation (np.ndarray): Features for validation. + y_validation (np.ndarray): Labels for validation. + training_parameters (Optional[Dict[str, Any]]): Additional training parameters. + balance_train_classes (bool): Whether to balance the training classes. + + Raises: + ValueError: If parameters for xgb.Booster are not initialized before training. + NotImplementedError: If the model_class is not supported for training. + """ + # if additional training parameters are provided + if training_parameters: + # ensure the provided training parameters are valid + valid_param_sets = [valid_xgboost_params, valid_xgboost_custom_params] + valid_training_params = self.validate_params(training_parameters, valid_param_sets) + # update the model's params with the validated training params + if self.params is not None : + self.params.update(valid_training_params) + else: + self.params = valid_training_params + weights = self.balance_train_weights(y_train) if balance_train_classes else training_parameters.get('training_weights', np.ones_like(y_train)) + evaluation_metrics = training_parameters.get('custom_eval_metrics', self.params.get('eval_metric', ["Accuracy"]) if self.params else ["Accuracy"]) + num_boost_rounds = training_parameters.get('num_boost_rounds', self.params.get('num_boost_rounds', 10) if self.params else 10) + + # if not additional training parameters were provided, use the model's params + else: + weights = self.balance_train_weights(y_train) if balance_train_classes else np.ones_like(y_train) + evaluation_metrics = self.params.get('eval_metric', ["Accuracy"]) if self.params else ["Accuracy"] + num_boost_rounds = self.params.get('num_boost_rounds', 10) if self.params else 10 + + if not self.params: + raise ValueError("Parameters must be initialized before training.") + + # take off the custom training parameters + filtered_params = {k: v for k, v in self.params.items() if k not in valid_xgboost_custom_params} + + if self.model_class is xgb.Booster: + # train the xgb.Booster + dtrain = self._ensure_dmatrix(x_train, y_train, weights) + dval = self._ensure_dmatrix(x_validation, y_validation) + self.model = xgb.train(filtered_params, dtrain, num_boost_round=num_boost_rounds, evals=[(dval, 'eval')], verbose_eval=False) + self.evaluate(x_validation, y_validation, eval_metrics=evaluation_metrics, print_results=True) + elif self.model_class is xgb.XGBClassifier: + # train the xgb.XGBClassifier + self.model = self.model_class(**filtered_params) + self.model.fit(x_train, y_train, sample_weight=weights, eval_set=[(x_validation, y_validation)]) + self.evaluate(x_validation, y_validation, eval_metrics=evaluation_metrics, print_results=True) + else: + # if the class is neither xgb.XGBClassifier or xgb.Booster, raise an error. + raise NotImplementedError(f"Training not implemented for model class {self.model_class}") + + def predict(self, X: np.ndarray, return_proba: bool = False, threshold: float = 0.5) -> np.ndarray: + """ + Makes predictions using the model for the given input. + + Args: + X (np.ndarray): Features for prediction. + return_proba (bool, optional): Whether to return probabilities. Defaults to False. + threshold (float, optional): Threshold for converting probabilities to class labels. Defaults to 0.5. + + Returns: + np.ndarray: Predictions made by the model. + + Raises: + ValueError: If the model has not been initialized. + NotImplementedError: If prediction is not implemented for the model class. + """ + if self.model is None: + raise ValueError(f"The {self.model_class.__name__} model has not been initialized.") + + if self.model_class is xgb.Booster: + dtest = self._ensure_dmatrix(X) + preds = self.model.predict(dtest) + elif self.model_class is xgb.XGBClassifier: + preds = self.model.predict_proba(X) if return_proba else self.model.predict(X) + else: + raise NotImplementedError(f"Prediction not implemented for model class {self.model_class}") + + return preds if return_proba else (preds > threshold).astype(int) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """ + Makes predictions using the model for the given input. + + Args: + X (np.ndarray): Features for prediction. + + Returns: + np.ndarray: Predictions made by the model. + + Raises: + ValueError: If the model has not been initialized. + NotImplementedError: If prediction is not implemented for the model class. + """ + if self.model is None: + raise ValueError(f"The {self.model_class.__name__} model has not been initialized.") + + if self.model_class is xgb.Booster: + dtest = self._ensure_dmatrix(X) + preds = self.model.predict(dtest) + elif self.model_class is xgb.XGBClassifier: + preds = self.model.predict_proba(X) + else: + raise NotImplementedError(f"Prediction not implemented for model class {self.model_class}") + + return preds + + def train_to_disagree(self, x_train: np.ndarray, y_train: np.ndarray, + x_validation: np.ndarray, y_validation: np.ndarray, + x_test: np.ndarray, y_test: np.ndarray, + training_parameters: Optional[Dict[str, Any]], + balance_train_classes: bool, n_samples: int) -> None: + """ + Trains the model to disagree with another model using a specified dataset. + + This method is intended for scenarios where the model is trained to produce outputs that + intentionally diverge from those of another model, to be used in the ``detectron`` method + + Args: + x_train (np.ndarray): Features for training. + y_train (np.ndarray): Labels for training. + x_validation (np.ndarray): Features for validation. + y_validation (np.ndarray): Labels for validation. + x_test (np.ndarray): Features for testing or disagreement evaluation. + y_test (np.ndarray): Labels for testing or disagreement evaluation. + training_parameters (Optional[Dict[str, Any]]): Additional parameters for training the model. + balance_train_classes (bool): Whether to balance the class distribution in the training data. + n_samples (int): The number of examples in the testing set that should be used for calculating disagreement. + + Raises: + ValueError: If the necessary parameters for training are not properly initialized. + NotImplementedError: If the model class does not support this type of training. + """ + if training_parameters: + valid_param_sets = [valid_xgboost_params, valid_xgboost_custom_params] + valid_training_params = self.validate_params(training_parameters, valid_param_sets) + if self.params is not None: + self.params.update(valid_training_params) + else: + self.params = valid_training_params + training_weights = self.balance_train_weights(y_train) if balance_train_classes else training_parameters.get('training_weights', np.ones_like(y_train)) + evaluation_metrics = training_parameters.get('custom_eval_metrics', self.params.get('eval_metric', ["Accuracy"]) if self.params else ["Accuracy"]) + else: + training_weights = np.ones_like(y_train) + evaluation_metrics = self.params.get('eval_metric', ["Accuracy"]) if self.params else ["Accuracy"] + + if not self.params: + raise ValueError("Parameters must be initialized before training.") + + filtered_params = {k: v for k, v in self.params.items() if k not in valid_xgboost_custom_params} + + # prepare the data for training + data = np.concatenate([x_train, x_test]) + label = np.concatenate([y_train, 1 - y_test]) + weight = np.concatenate([training_weights, 1 / (n_samples + 1) * np.ones(n_samples)]) + + if self.model_class is xgb.Booster: + dtrain = self._ensure_dmatrix(data, label, weight) + dval = self._ensure_dmatrix(x_validation, y_validation) + self.model = xgb.train(filtered_params, dtrain, num_boost_round=10, evals=[(dval, 'eval')], verbose_eval=False) + self.evaluate(x_validation, y_validation, eval_metrics=evaluation_metrics) + elif self.model_class is xgb.XGBClassifier: + self.model = self.model_class(**filtered_params) + self.model.fit(data, label, sample_weight=weight, eval_set=[(x_validation, y_validation)]) + self.evaluate(x_validation, y_validation, eval_metrics=evaluation_metrics) + else: + raise NotImplementedError(f"Training not implemented for model class {self.model_class}") + \ No newline at end of file diff --git a/MED3pa/models/concrete_regressors.py b/MED3pa/models/concrete_regressors.py new file mode 100644 index 0000000..d49db1a --- /dev/null +++ b/MED3pa/models/concrete_regressors.py @@ -0,0 +1,316 @@ +"""Similar to ``concrete_classifiers.py``, this module contains implementations of regression models like +RandomForestRegressor and DecisionTreeRegressor. It provides practical, ready-to-use models that comply with the +abstract definitions, making it easier to integrate and use these models in ``med3pa``.""" + +import numpy as np +from copy import deepcopy +from sklearn.ensemble import RandomForestRegressor +from sklearn.tree import DecisionTreeRegressor +from sklearn.utils import resample +from typing import Any, Dict, Optional + +from .abstract_models import RegressionModel +from .data_strategies import ToNumpyStrategy +from .regression_metrics import * + + +class RandomForestRegressorModel(RegressionModel): + """ + A concrete implementation of the Model class for RandomForestRegressor models. + """ + def __init__(self, params: Dict[str, Any]) -> None: + """ + Initializes the RandomForestRegressorModel with a scikit-learn RandomForestRegressor. + + Args: + params (dict): Parameters for initializing the RandomForestRegressor. + """ + super().__init__() + self.params = params + self.model = RandomForestRegressor(**params) + self.model_class = RandomForestRegressor + self.pickled_model = False + self.data_preparation_strategy = ToNumpyStrategy() + + def train(self, x_train: np.ndarray, y_train: np.ndarray, x_validation: np.ndarray = None, + y_validation: np.ndarray = None, training_parameters: Optional[Dict[str, Any]] = None, **params) -> None: + """ + Trains the model on the provided dataset. + + Args: + x_train (np.ndarray): observations for training. + y_train (np.ndarray): Labels for training. + x_validation (np.ndarray, optional): observations for validation. + y_validation (np.ndarray, optional): Labels for validation. + training_parameters (dict, optional): Additional training parameters. + + Raises: + ValueError: If the RandomForestRegressorModel has not been initialized before training. + """ + if self.model is None: + raise ValueError("The RandomForestRegressor has not been initialized.") + + np_X_train, np_y_train = self._ensure_numpy_arrays(x_train, y_train) + + if training_parameters: + valid_param_sets = [set(self.model.get_params())] + validated_params = self.validate_params(training_parameters, valid_param_sets) + self.params.update(validated_params) + self.model.set_params(**self.params) + + self.model.fit(np_X_train, np_y_train, **params) + + if x_validation is not None and y_validation is not None: + self.evaluate(x_validation, y_validation, ['RMSE', 'MSE'], True) + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Makes predictions with the model for the given input. + + Args: + X (np.ndarray): observations for prediction. + + Returns: + np.ndarray: Predictions made by the model. + + Raises: + ValueError: If the RandomForestRegressorModel has not been initialized before training. + """ + if self.model is None: + raise ValueError("The RandomForestRegressorModel has not been initialized.") + else: + np_X, _ = self._ensure_numpy_arrays(X) + return self.model.predict(np_X) + + +class EnsembleRandomForestRegressorModel(RegressionModel): + """ + An ensemble model consisting of multiple RandomForestRegressorModel instances, + with undersampling applied to the majority class. + """ + + def __init__(self, params: Dict[str, Any] = None, + base_model: RandomForestRegressorModel = RandomForestRegressorModel, + n_models: int = 10, + random_state: int = None, + **params_sklearn) -> None: + """ + Initializes the EnsembleRandomForestRegressorModel with multiple RandomForestRegressor models. + + Args: + params (Dict[str, Any]): A dictionary of parameters for each model in the ensemble. + base_model (RandomForestRegressorModel): A prototype instance of RandomForestRegressorModel. + n_models (int): The number of RandomForestRegressorModel instances in the ensemble. + random_state (int): A random_state can be set for reproducibility. + **params_sklearn (Any) : Parameters for the sklearn model. + """ + if params is None: + params = {} + if params_sklearn is not None: + params.update(params_sklearn) + super().__init__() + self.params = params + self.n_models = n_models + self.models = [] + self.random_state = random_state + for n_model in range(n_models): + model = deepcopy(base_model(params)) + self.models.append(model) + self.model = self + + self.fit = self.train + + def _undersample(self, x: np.ndarray, y: np.ndarray, sample_weight: np.ndarray) -> tuple: + """ + Applies undersampling to the majority class based on sample weights. + Samples with lower sample weights are undersampled, while higher weighted samples are retained. + """ + # Sort data by sample_weight + sorted_indices = np.argsort(sample_weight) + x_sorted = x[sorted_indices] + y_sorted = y[sorted_indices] + sample_weight_sorted = sample_weight[sorted_indices] + + # Identify the threshold to differentiate between lower and higher sample weights + weight_threshold = np.median(sample_weight_sorted) + + # Split into "higher weight" and "lower weight" groups + x_higher_weight = x_sorted[sample_weight_sorted >= weight_threshold] + y_higher_weight = y_sorted[sample_weight_sorted >= weight_threshold] + sample_weight_higher = sample_weight_sorted[sample_weight_sorted >= weight_threshold] + + x_lower_weight = x_sorted[sample_weight_sorted < weight_threshold] + y_lower_weight = y_sorted[sample_weight_sorted < weight_threshold] + sample_weight_lower = sample_weight_sorted[sample_weight_sorted < weight_threshold] + + # Undersample the lower-weight group + if len(y_lower_weight) > len(y_higher_weight): + x_lower_weight_resampled, y_lower_weight_resampled, sample_weight_lower_resampled = resample( + x_lower_weight, y_lower_weight, sample_weight_lower, + replace=False, n_samples=len(y_higher_weight), + random_state=self.random_state + ) + else: + x_lower_weight_resampled, y_lower_weight_resampled, sample_weight_lower_resampled = ( + x_lower_weight, y_lower_weight, sample_weight_lower + ) + + # Combine higher-weight samples and resampled lower-weight samples + x_resampled = np.vstack((x_higher_weight, x_lower_weight_resampled)) + y_resampled = np.hstack((y_higher_weight, y_lower_weight_resampled)) + sample_weight_resampled = np.hstack((sample_weight_higher, sample_weight_lower_resampled)) + + return x_resampled, y_resampled, sample_weight_resampled + + def __sklearn_clone__(self): + """ + Overwrites the sklearn clone function + """ + new_instance = deepcopy(self) + return new_instance + + def train(self, x_train: np.ndarray, y_train: np.ndarray, x_validation: np.ndarray = None, + y_validation: np.ndarray = None, training_parameters: Optional[Dict[str, Any]] = None, **params) -> None: + """ + Trains each model in the ensemble on a differently resampled dataset. + """ + np_X_train, np_y_train = self._ensure_numpy_arrays(x_train, y_train) + + if training_parameters: + self.params.update(training_parameters) + if "sample_weight" not in params: + raise ValueError("EnsembleRandomForestRegressorModel must be trained with a sample_weight parameter") + sample_weight = params["sample_weight"] + + for model in self.models: + # Resample the dataset for each model + x_resampled, y_resampled, sample_weight = self._undersample(np_X_train, np_y_train, sample_weight) + model.train(x_resampled, y_resampled, x_validation, y_validation, training_parameters, **params) + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Makes predictions with the ensemble model by averaging predictions from each model. + """ + np_X, _ = self._ensure_numpy_arrays(X) + + predictions = np.zeros((self.n_models, len(np_X))) + + for i, model in enumerate(self.models): + predictions[i] = model.predict(np_X) + + return np.mean(predictions, axis=0) + + def score(self, X, y, sample_weight=None): + """ Taken from sklearn.base.py + Return the coefficient of determination of the prediction. + + The coefficient of determination :math:`R^2` is defined as + :math:`(1 - \\frac{u}{v})`, where :math:`u` is the residual + sum of squares ``((y_true - y_pred)** 2).sum()`` and :math:`v` + is the total sum of squares ``((y_true - y_true.mean()) ** 2).sum()``. + The best possible score is 1.0 and it can be negative (because the + model can be arbitrarily worse). A constant model that always predicts + the expected value of `y`, disregarding the input features, would get + a :math:`R^2` score of 0.0. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Test samples. For some estimators this may be a precomputed + kernel matrix or a list of generic objects instead with shape + ``(n_samples, n_samples_fitted)``, where ``n_samples_fitted`` + is the number of samples used in the fitting for the estimator. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + True values for `X`. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + score : float + :math:`R^2` of ``self.predict(X)`` w.r.t. `y`. + + Notes + ----- + The :math:`R^2` score used when calling ``score`` on a regressor uses + ``multioutput='uniform_average'`` from version 0.23 to keep consistent + with default value of :func:`~sklearn.metrics.r2_score`. + This influences the ``score`` method of all the multioutput + regressors (except for + :class:`~sklearn.multioutput.MultiOutputRegressor`). + """ + + from sklearn.metrics import r2_score + + y_pred = self.predict(X) + return r2_score(y, y_pred, sample_weight=sample_weight) + + +class DecisionTreeRegressorModel(RegressionModel): + """ + A concrete implementation of the Model class for DecisionTree models. + """ + def __init__(self, params: Dict[str, Any]) -> None: + """ + Initializes the DecisionTreeRegressorModel with a scikit-learn DecisionTreeRegressor. + + Args: + params (dict): Parameters for initializing the DecisionTreeRegressor. + """ + super().__init__() + self.params = params + self.model = DecisionTreeRegressor(**params) + self.model_class = DecisionTreeRegressor + self.pickled_model = False + self.data_preparation_strategy = ToNumpyStrategy() + + def train(self, x_train: np.ndarray, y_train: np.ndarray, x_validation: np.ndarray = None, y_validation: np.ndarray = None, training_parameters: Optional[Dict[str, Any]] = None) -> None: + """ + Trains the model on the provided dataset. + + Args: + x_train (np.ndarray): observations for training. + y_train (np.ndarray): Targets for training. + x_validation (np.ndarray, optional): observations for validation. + y_validation (np.ndarray, optional): Targets for validation. + training_parameters (dict, optional): Additional training parameters. + + Raises: + ValueError: If the DecisionTreeRegressorModel has not been initialized before training. + """ + if self.model is None: + raise ValueError("The DecisionTreeRegressorModel has not been initialized.") + + np_X_train, np_y_train = self._ensure_numpy_arrays(x_train, y_train) + + if training_parameters: + valid_param_sets = [set(self.model.get_params())] + validated_params = self.validate_params(training_parameters, valid_param_sets) + self.params.update(validated_params) + self.model.set_params(**self.params) + + self.model.fit(np_X_train, np_y_train) + + if x_validation is not None and y_validation is not None: + self.evaluate(x_validation, y_validation, ['RMSE', 'MSE'], True) + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Makes predictions with the model for the given input. + + Args: + X (np.ndarray): observations for prediction. + + Returns: + np.ndarray: Predictions made by the model. + + Raises: + ValueError: If the DecisionTreeRegressorModel has not been initialized before training. + """ + if self.model is None: + raise ValueError("The DecisionTreeRegressorModel has not been initialized.") + else: + np_X, _ = self._ensure_numpy_arrays(X) + return self.model.predict(np_X) diff --git a/MED3pa/models/data_strategies.py b/MED3pa/models/data_strategies.py new file mode 100644 index 0000000..874a56e --- /dev/null +++ b/MED3pa/models/data_strategies.py @@ -0,0 +1,148 @@ +""" +This module is crucial for data handling, utilizing the **Strategy design pattern** and therefore offering multiple strategies to transform raw data into formats that enhance model training and evaluation. +According to the model type. +""" + +import numpy as np +import pandas as pd +import scipy.sparse as sp +import xgboost as xgb +from typing import Optional + + +class DataPreparingStrategy: + """ + Abstract base class for data preparation strategies. + """ + + @staticmethod + def execute(observations: np.ndarray, labels: Optional[np.ndarray] = None, weights: Optional[np.ndarray] = None + ) -> tuple: + """ + Prepares data for model training or prediction. + + Args: + observations (array-like): observations array. + labels (array-like, optional): Labels array. + weights (array-like, optional): Weights array. + + Returns: + object: Prepared data in the required format for the model. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. + """ + raise NotImplementedError("Subclasses must implement this method.") + + +class ToDmatrixStrategy(DataPreparingStrategy): + """ + Concrete implementation for converting data into DMatrix format suitable for XGBoost models. + """ + + @staticmethod + def is_supported_data(observations: np.ndarray, labels: Optional[np.ndarray] = None, + weights: Optional[np.ndarray] = None) -> bool: + """ + Checks if the data types of observations, labels, and weights are supported for conversion to DMatrix. + + Args: + observations (array-like): observations data. + labels (array-like, optional): Labels data. + weights (array-like, optional): Weights data. + + Returns: + bool: True if all data types are supported, False otherwise. + """ + supported_types = [np.ndarray, pd.DataFrame, sp.spmatrix, list] + is_supported = lambda data: any(isinstance(data, t) for t in supported_types) + + return all(is_supported(data) for data in [observations, labels, weights] if data is not None) + + @staticmethod + def execute(observations: np.ndarray, labels: Optional[np.ndarray] = None, + weights: Optional[np.ndarray] = None) -> xgb.DMatrix: + """ + Converts observations, labels, and weights into an XGBoost DMatrix. + + Args: + observations (array-like): observations data. + labels (array-like, optional): Labels data. + weights (array-like, optional): Weights data. + + Returns: + xgb.DMatrix: A DMatrix object ready for use with XGBoost. + + Raises: + ValueError: If any input data types are not supported. + """ + if not ToDmatrixStrategy.is_supported_data(observations, labels, weights): + raise ValueError("Unsupported data type provided for creating DMatrix.") + return xgb.DMatrix(data=observations, label=labels, weight=weights) + + +class ToNumpyStrategy(DataPreparingStrategy): + """ + Converts input data to NumPy arrays, ensuring compatibility with models expecting NumPy inputs. + """ + + @staticmethod + def execute(observations: np.ndarray, labels: Optional[np.ndarray] = None, + weights: Optional[np.ndarray] = None) -> tuple: + """ + Converts observations, labels, and weights into NumPy arrays. + + Args: + observations (array-like): observations data. + labels (array-like, optional): Labels data. + weights (array-like, optional): Weights data. + + Returns: + tuple: A tuple of NumPy arrays for observations, labels, and weights. Returns None for labels and weights if they are not provided. + + Raises: + ValueError: If the observations or labels are empty arrays. + """ + obs_np = np.asarray(observations) + labels_np = np.asarray(labels) if labels is not None else None + weights_np = np.asarray(weights) if weights is not None else None + + if obs_np.size == 0: + raise ValueError("Observations array cannot be empty.") + if labels is not None and labels_np.size == 0: + raise ValueError("Labels array cannot be empty.") + + return obs_np, labels_np, weights_np + + +class ToDataframesStrategy(DataPreparingStrategy): + """ + Converts input data to pandas DataFrames, suitable for models requiring DataFrame inputs. + """ + + @staticmethod + def execute(column_labels: list, observations: np.ndarray, labels: np.ndarray = None, + weights: np.ndarray = None) -> tuple: + """ + Converts observations, labels, and weights into pandas DataFrames with specified column labels. + + Args: + column_labels (list): Column labels for the observations DataFrame. + observations (np.ndarray): observations array. + labels (np.ndarray, optional): Labels array. + weights (np.ndarray, optional): Weights array. + + Returns: + tuple: DataFrames for observations, labels, and weights. Returns None for labels and weights DataFrames if not provided. + + Raises: + ValueError: If the observations array is empty. + """ + if observations.size == 0: + raise ValueError("observations array cannot be empty.") + + X_df = pd.DataFrame(observations, columns=column_labels) + Y_df = pd.DataFrame(labels, columns=['Label']) if labels is not None else None + W_df = pd.DataFrame(weights, columns=['Weight']) if weights is not None else None + + return X_df, Y_df, W_df diff --git a/MED3pa/models/dtr_params.py b/MED3pa/models/dtr_params.py new file mode 100644 index 0000000..8772a1c --- /dev/null +++ b/MED3pa/models/dtr_params.py @@ -0,0 +1,25 @@ +dtr_params = [ + {"name": "criterion", "type": "string", "default": "squared_error", "choices": ["squared_error", "friedman_mse", "absolute_error", "poisson"], "description": "The function to measure the quality of a split."}, + {"name": "splitter", "type": "string", "default": "best", "choices": ["best", "random"], "description": "The strategy used to choose the split at each node."}, + {"name": "max_depth", "type": "int", "default": 4, "description": "The maximum depth of the tree. Increasing this value will make the model more complex."}, + {"name": "min_samples_split", "type": "int", "default": 2, "description": "The minimum number of samples required to split an internal node."}, + {"name": "min_samples_leaf", "type": "int", "default": 1, "description": "The minimum number of samples required to be at a leaf node."}, + {"name": "min_weight_fraction_leaf", "type": "float", "default": 0.0, "description": "The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node."}, + {"name": "max_features", "type": "string", "default": "auto", "choices": ["auto", "sqrt", "log2"], "description": "The number of features to consider when looking for the best split."}, + {"name": "max_leaf_nodes", "type": "int", "default": 100, "description": "Grow trees with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity."}, + {"name": "min_impurity_decrease", "type": "float", "default": 0.0, "description": "A node will be split if this split induces a decrease of the impurity greater than or equal to this value."}, + {"name": "ccp_alpha", "type": "float", "default": 0.0, "description": "Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than ccp_alpha will be chosen."} +] + +dtr_gridsearch_params = [ + {"name": "criterion", "type": "string", "default": ["squared_error", "friedman_mse", "absolute_error", "poisson"], "description": "The function to measure the quality of a split."}, + {"name": "splitter", "type": "string", "default": ["best", "random"], "description": "The strategy used to choose the split at each node."}, + {"name": "max_depth", "type": "int", "default": [2, 3, 4, 5, 6], "description": "The maximum depth of the tree. Increasing this value will make the model more complex."}, + {"name": "min_samples_split", "type": "int", "default": [2, 5, 10], "description": "The minimum number of samples required to split an internal node."}, + {"name": "min_samples_leaf", "type": "int", "default": [1, 2, 4], "description": "The minimum number of samples required to be at a leaf node."}, + {"name": "min_weight_fraction_leaf", "type": "float", "default": [0.0, 0.1, 0.2], "description": "The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node."}, + {"name": "max_features", "type": "string", "default": ["auto", "sqrt", "log2"], "description": "The number of features to consider when looking for the best split."}, + {"name": "max_leaf_nodes", "type": "int", "default": [10, 20, 30], "description": "Grow trees with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity."}, + {"name": "min_impurity_decrease", "type": "float", "default": [0.0, 0.1, 0.2], "description": "A node will be split if this split induces a decrease of the impurity greater than or equal to this value."}, + {"name": "ccp_alpha", "type": "float", "default": [0.0, 0.01, 0.1], "description": "Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than ccp_alpha will be chosen."} +] diff --git a/MED3pa/models/factories.py b/MED3pa/models/factories.py new file mode 100644 index 0000000..790094d --- /dev/null +++ b/MED3pa/models/factories.py @@ -0,0 +1,260 @@ +""" +This module utilizes the **Factory design pattern** to abstract the creation process of machine learning models. +It defines a general factory class and specialized factories for different model types, such as XGBoost. +This setup allows for dynamic model instantiation based on provided specifications or configurations. +By decoupling model creation from usage""" + +import json +import pickle +import re +import warnings +import xgboost as xgb +from typing import Union + +from .abstract_models import Model +from .concrete_classifiers import XGBoostModel + + +class ModelFactory: + """ + A factory class for creating models with different types, using the factory design pattern. + It supports creating models based on hyperparameters or loading them from pickled files. + """ + + model_mapping = { + 'XGBoostModel': [xgb.Booster, xgb.XGBClassifier], + } + + factories = { + 'XGBoostModel': lambda: XGBoostFactory(), + } + + @staticmethod + def get_factory(model_type: str) -> 'ModelFactory': + """ + Retrieves the factory object for the given model type. + + Args: + model_type (str): The type of model for which the factory is to be retrieved. + + Returns: + ModelFactory: An instance of the factory associated with the given model type. + + Raises: + ValueError: If no factory is available for the given model type. + """ + factory_initializer = ModelFactory.factories.get(model_type) + if factory_initializer: + return factory_initializer() + else: + raise ValueError(f"No factory available for model type: {model_type}") + + @staticmethod + def get_supported_models() -> list: + """ + Retrieves a list of all supported model types. + + Returns: + list: A list containing the keys from model_mapping which represent the supported model types. + """ + return list(ModelFactory.model_mapping) + + @staticmethod + def create_model_with_hyperparams(model_type: str, hyperparams: dict) -> Model: + """ + Creates a model of the specified type with the given hyperparameters. + + Args: + model_type (str): The type of model to create. + hyperparams (dict): A dictionary of hyperparameters for the model. + + Returns: + Model: A model instance of the specified type, initialized with the given hyperparameters. + """ + factory = ModelFactory.get_factory(model_type) + return factory.create_model_with_hyperparams(hyperparams) + + @staticmethod + def create_model_from_pickled(pickled_file_path: str) -> Model: + """ + Creates a model by loading it from a pickled file. + + Args: + pickled_file_path (str): The file path to the pickled model file. + + Returns: + Model: A model instance loaded from the pickled file. + + Raises: + IOError: If there is an error loading the model from the file. + TypeError: If the loaded model is not of a supported type. + """ + warnings.filterwarnings("ignore", message=r".*WARNING.*", category=UserWarning, module="xgboost.core") + try: + with open(pickled_file_path, 'rb') as file: + loaded_model = pickle.load(file) + except Exception as e: + raise IOError(f"Failed to load the model from {pickled_file_path}: {e}") + + for model_type, model_classes in ModelFactory.model_mapping.items(): + if any(isinstance(loaded_model, model_class) for model_class in model_classes): + factory = ModelFactory.get_factory(model_type) + return factory.create_model_from_pickled(pickled_file_path) + + raise TypeError("The loaded model is not of a supported type") + + +class XGBoostFactory(ModelFactory): + """ + A factory for creating XGBoost model objects, either from hyperparameters or by loading from pickled files. + Inherits from ModelFactory and specifies creation methods for XGBoost models. + """ + + def create_model_with_hyperparams(self, hyperparams: dict) -> XGBoostModel: + """ + Creates an XGBoostModel with the given hyperparameters. + + Args: + hyperparams (dict): A dictionary of hyperparameters for the XGBoost model. + + Returns: + XGBoostModel: An instance of XGBoostModel initialized with the given hyperparameters. + """ + return XGBoostModel(params=hyperparams) + + def create_model_from_pickled(self, pickled_file_path: str) -> XGBoostModel: + """ + Recreates an XGBoostModel from a loaded pickled model. + + Args: + pickled_file_path (str): The file path to the pickled model file. + + Returns: + XGBoostModel: An instance of XGBoostModel created from the loaded model. + + Raises: + IOError: If there is an error loading the model from the file. + TypeError: If the loaded model is not a supported implementation of the XGBoost model. + ValueError: If the XGBoost model version is not supported. + """ + warnings.filterwarnings("ignore", message=r".*WARNING.*", category=UserWarning, module="xgboost.core") + try: + with open(pickled_file_path, 'rb') as file: + loaded_model = pickle.load(file) + except Exception as e: + raise IOError(f"Failed to load the model from {pickled_file_path}: {e}") + + if isinstance(loaded_model, (xgb.Booster, xgb.XGBClassifier)): + if self.check_version(loaded_model): + extracted_params = self.extract_params(loaded_model) + xgb_model = XGBoostModel(params=extracted_params, model=loaded_model) + xgb_model.set_file_path(file_path=pickled_file_path) + return xgb_model + else: + raise ValueError("XGBoost model version is not supported. Please use version 2.0.0 or later.") + else: + raise TypeError("Loaded model is not an XGBoost model") + + def check_version(self, loaded_model: Union[xgb.Booster, xgb.XGBClassifier]) -> bool: + """ + Checks the version of the loaded XGBoost model to ensure it is supported. + + Args: + loaded_model (xgb.Booster | xgb.XGBClassifier): The loaded model object. + + Returns: + bool: True if the model version is supported, False otherwise. + """ + config_json = loaded_model.save_config() + config = json.loads(config_json) + version_list = config.get('version', 'Not available') + if isinstance(version_list, list): + version_str = '.'.join(map(str, version_list)) + else: + version_str = version_list + + version_match = re.match(r'(\d+)\.(\d+)\.(\d+)', version_str) + if version_match: + major, minor, patch = map(int, version_match.groups()) + return (major, minor, patch) >= (2, 0, 0) + else: + return False + + def extract_params(self, loaded_model: Union[xgb.Booster, xgb.XGBClassifier]) -> dict: + """ + Extracts the parameters from a loaded XGBoost model. + + Args: + loaded_model (xgb.Booster | xgb.XGBClassifier): The loaded model object. + + Returns: + dict: A dictionary of extracted parameters. + """ + try: + boosted_rounds = loaded_model.num_boosted_rounds() + config_json = loaded_model.save_config() + config = json.loads(config_json) + except AttributeError as e: + print(f"Error extracting configuration from model: {e}") + return {} + except json.JSONDecodeError as e: + print(f"Error decoding JSON configuration: {e}") + return {} + + try: + learner = config['learner'] + gradient_booster = learner['gradient_booster'] + + general_params = learner['generic_param'] + booster_params = gradient_booster.get('gbtree_train_param', {}) + tree_params = gradient_booster.get('tree_train_param', {}) + + updater_params = {} + if 'updater' in gradient_booster and isinstance(gradient_booster['updater'], list): + for updater in gradient_booster['updater']: + if 'hist_train_param' in updater: + updater_params.update(updater['hist_train_param']) + + learning_task_params = learner['learner_train_param'] + objective_params = learner['objective'].get('reg_loss_param', {}) + learner_model_params = learner['learner_model_param'] + + params = {} + params.update(general_params) + params.update(booster_params) + params.update(tree_params) + params.update(updater_params) + params.update(learning_task_params) + params.update(objective_params) + params.update(learner_model_params) + + if 'metrics' in learner: + metrics = [metric['name'] for metric in learner['metrics']] + if metrics: + params['eval_metric'] = metrics + params['num_boost_rounds'] = boosted_rounds + + for key, value in params.items(): + try: + if isinstance(value, str): + if '.' in value or 'E' in value or 'e' in value: + params[key] = float(value) + elif value.isdigit(): + params[key] = int(value) + else: + # Skip conversion for non-numeric strings + continue + else: + params[key] = int(value) + except (ValueError, TypeError) as e: + continue + + removable_keys = ['num_trees'] + for key in removable_keys: + params.pop(key, None) + + except KeyError as e: + print(f"Key error while extracting parameters: {e}") + return {} + + return params diff --git a/MED3pa/models/regression_metrics.py b/MED3pa/models/regression_metrics.py new file mode 100644 index 0000000..9955c96 --- /dev/null +++ b/MED3pa/models/regression_metrics.py @@ -0,0 +1,117 @@ +""" +The ``regression_metrics.py`` module defines the ``RegressionEvaluationMetrics`` class, +that contains various regression metrics that can be used to assess the model's performance. +""" +import numpy as np +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from typing import Callable, List + +from .abstract_metrics import EvaluationMetric + + +class RegressionEvaluationMetrics(EvaluationMetric): + """ + A class to compute various regression evaluation metrics. + """ + + @staticmethod + def mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None) -> float: + """ + Calculate the Mean Squared Error (MSE). + + Args: + y_true (np.ndarray): True values. + y_pred (np.ndarray): Predicted values. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Mean Squared Error. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + return mean_squared_error(y_true, y_pred, sample_weight=sample_weight) + + @staticmethod + def root_mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None) -> float: + """ + Calculate the Root Mean Squared Error (RMSE). + + Args: + y_true (np.ndarray): True values. + y_pred (np.ndarray): Predicted values. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Root Mean Squared Error. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + return np.sqrt(mean_squared_error(y_true, y_pred, sample_weight=sample_weight)) + + @staticmethod + def mean_absolute_error(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None) -> float: + """ + Calculate the Mean Absolute Error (MAE). + + Args: + y_true (np.ndarray): True values. + y_pred (np.ndarray): Predicted values. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: Mean Absolute Error. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + return mean_absolute_error(y_true, y_pred, sample_weight=sample_weight) + + @staticmethod + def r2_score(y_true: np.ndarray, y_pred: np.ndarray, sample_weight: np.ndarray = None) -> float: + """ + Calculate the R-squared (R2) score. + + Args: + y_true (np.ndarray): True values. + y_pred (np.ndarray): Predicted values. + sample_weight (np.ndarray, optional): Sample weights. + + Returns: + float: R-squared score. + """ + if y_true.size == 0 or y_pred.size == 0: + return None + return r2_score(y_true, y_pred, sample_weight=sample_weight) + + @classmethod + def get_metric(cls, metric_name: str = '') -> Callable | List[str]: + """ + Get the metric function based on the metric name. + + Args: + metric_name (str): The name of the metric. + + Returns: + function: The function corresponding to the metric. + List: List of available metric names. + """ + metrics_mappings = { + 'MSE': cls.mean_squared_error, + 'RMSE': cls.root_mean_squared_error, + 'MAE': cls.mean_absolute_error, + 'R2': cls.r2_score + } + if metric_name == '': + return list(metrics_mappings) + else: + metric_function = metrics_mappings.get(metric_name) + return metric_function + + @classmethod + def supported_metrics(cls) -> List[str]: + """ + Get a list of supported classification metrics. + + Returns: + list: A list of supported classification metrics. + """ + return cls.get_metric() diff --git a/MED3pa/models/rfr_params.py b/MED3pa/models/rfr_params.py new file mode 100644 index 0000000..4edce8a --- /dev/null +++ b/MED3pa/models/rfr_params.py @@ -0,0 +1,28 @@ +rfr_params = [ + {"name": "n_estimators", "type": "int", "default": 100, "description": "The number of trees in the forest."}, + {"name": "criterion", "type": "string", "choices": ["squared_error", "absolute_error", "poisson"], "default": "squared_error", "description": "The function to measure the quality of a split."}, + {"name": "max_depth", "type": "int", "default": 4, "description": "The maximum depth of the tree. Increasing this value will make the model more complex."}, + {"name": "min_samples_split", "type": "int", "default": 2, "description": "The minimum number of samples required to split an internal node."}, + {"name": "min_samples_leaf", "type": "int", "default": 1, "description": "The minimum number of samples required to be at a leaf node."}, + {"name": "min_weight_fraction_leaf", "type": "float", "default": 0.0, "description": "The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node."}, + {"name": "max_features", "type": "string", "choices": ["auto", "sqrt", "log2"], "default": "auto", "description": "The number of features to consider when looking for the best split."}, + {"name": "max_leaf_nodes", "type": "int", "default": 100, "description": "Grow trees with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity."}, + {"name": "min_impurity_decrease", "type": "float", "default": 0.0, "description": "A node will be split if this split induces a decrease of the impurity greater than or equal to this value."}, + {"name": "bootstrap", "type": "bool", "default": True, "description": "Whether bootstrap samples are used when building trees. If False, the whole dataset is used to build each tree."}, + {"name": "oob_score", "type": "bool", "default": False, "description": "Whether to use out-of-bag samples to estimate the generalization score."}, + {"name": "n_jobs", "type": "int", "default": 1, "description": "The number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context."}, + {"name": "random_state", "type": "int", "default": 42, "description": "Controls both the randomness of the bootstrapping of the samples used when building trees (if bootstrap=True) and the sampling of the features to consider when looking for the best split at each node (if max_features < n_features)."}, + {"name": "verbose", "type": "int", "default": 0, "description": "Controls the verbosity when fitting and predicting."}, + {"name": "warm_start", "type": "bool", "default": False, "description": "When set to True, reuse the solution of the previous call to fit and add more estimators to the ensemble, otherwise, just fit a whole new forest."}, + {"name": "ccp_alpha", "type": "float", "default": 0.0, "description": "Complexity parameter used for Minimal Cost-Complexity Pruning. The subtree with the largest cost complexity that is smaller than ccp_alpha will be chosen."}, + {"name": "max_samples", "type": "float", "default": 0.9, "description": "If bootstrap is True, the number of samples to draw from X to train each base estimator."} +] + +rfr_gridsearch_params = [ + {"name": "n_estimators", "type": "int", "default": [100, 200, 300, 400, 500], "description": "The number of trees in the forest."}, + {"name": "max_depth", "type": "int", "default": [2, 3, 4, 5, 6], "description": "The maximum depth of the tree. Increasing this value will make the model more complex."}, + {"name": "min_samples_split", "type": "int", "default": [2, 5, 10], "description": "The minimum number of samples required to split an internal node."}, + {"name": "min_samples_leaf", "type": "int", "default": [1, 2, 4], "description": "The minimum number of samples required to be at a leaf node."}, + {"name": "max_features", "type": "string", "default": ["auto", "sqrt", "log2"], "description": "The number of features to consider when looking for the best split."}, + {"name": "bootstrap", "type": "bool", "default": [True, False], "description": "Whether bootstrap samples are used when building trees. If False, the whole dataset is used to build each tree."} +] diff --git a/MED3pa/models/xgboost_params.py b/MED3pa/models/xgboost_params.py new file mode 100644 index 0000000..02f10f1 --- /dev/null +++ b/MED3pa/models/xgboost_params.py @@ -0,0 +1,147 @@ +valid_xgboost_params = { + "booster", + "nthread", + "verbosity", + "validate_parameters", + "device", + "seed", + "seed_per_iteration", + "disable_default_eval_metric", + + # Tree Booster parameters + "eta", + "gamma", + "max_depth", + "min_child_weight", + "subsample", + "colsample_bytree", + "colsample_bylevel", + "colsample_bynode", + "lambda", + "alpha", + "tree_method", + "scale_pos_weight", + "max_bin", + "sampling_method", + "grow_policy", + "max_leaves", + "max_delta_step", + "refresh_leaf", + "process_type", + "updater", + "monotone_constraints", + "interaction_constraints", + + # DART Booster parameters + "sample_type", + "normalize_type", + "rate_drop", + "one_drop", + "skip_drop", + + # Linear Booster parameters + "feature_selector", + "top_k", + + # Learning task parameters + "objective", + "base_score", + "eval_metric", + + # Special objective parameters + "tweedie_variance_power", + "huber_slope", + "quantile_alpha", + "aft_loss_distribution", + + # Rank parameters + "lambdarank_pair_method", + "lambdarank_num_pair_per_sample", + "lambdarank_unbiased", + "lambdarank_bias_norm", + "ndcg_exp_gain", +} + +valid_xgboost_custom_params = { + # custom parameters + "custom_eval_metrics", + "num_boost_rounds", + "training_weights" +} + +xgboost_params = [ + {"name": "booster", "type": "string", "choices": ["gbtree", "gblinear", "dart"], "default": "gbtree"}, + {"name": "nthread", "type": "int", "default": None}, + {"name": "verbosity", "type": "int", "choices": [0, 1, 2, 3], "default": 1}, + {"name": "validate_parameters", "type": "bool", "default": False}, + {"name": "device", "type": "string", "choices": ["cpu", "cuda"], "default": "cpu"}, + {"name": "seed", "type": "int", "default": 0}, + {"name": "seed_per_iteration", "type": "bool", "default": 0}, + {"name": "disable_default_eval_metric", "type": "bool", "default": 0}, + + # Tree Booster parameters + {"name": "eta", "type": "float", "range": [0, 1], "default": 0.3}, + {"name": "gamma", "type": "float", "range": [0, float('inf')], "default": 0}, + {"name": "max_depth", "type": "int", "range": [0, float('inf')], "default": 6}, + {"name": "min_child_weight", "type": "float", "range": [0, float('inf')], "default": 1}, + {"name": "subsample", "type": "float", "range": [0, 1], "default": 1}, + {"name": "colsample_bytree", "type": "float", "range": [0, 1], "default": 1}, + {"name": "colsample_bylevel", "type": "float", "range": [0, 1], "default": 1}, + {"name": "colsample_bynode", "type": "float", "range": [0, 1], "default": 1}, + {"name": "lambda", "type": "float", "range": [0, float('inf')], "default": 1}, + {"name": "alpha", "type": "float", "range": [0, float('inf')], "default": 0}, + {"name": "tree_method", "type": "string", "choices": ["auto", "exact", "approx", "hist", "gpu_hist"], "default": "auto"}, + {"name": "scale_pos_weight", "type": "float", "default": 1}, + {"name": "max_bin", "type": "int", "default": 256}, + {"name": "sampling_method", "type": "string", "choices": ["uniform", "gradient_based"], "default": "uniform"}, + {"name": "grow_policy", "type": "string", "choices": ["depthwise", "lossguide"], "default": "depthwise"}, + {"name": "max_leaves", "type": "int", "default": 0}, + {"name": "max_delta_step", "type": "float", "default": 0}, + {"name": "refresh_leaf", "type": "int", "choices": [0, 1], "default": 1}, + {"name": "process_type", "type": "string", "choices": ["default", "update"], "default": "default"}, + {"name": "updater", "type": "string", "default": "grow_colmaker,prune"}, + {"name": "monotone_constraints", "type": "string", "default": "()"}, + {"name": "interaction_constraints", "type": "string", "default": ""}, + + # DART Booster parameters + {"name": "sample_type", "type": "string", "choices": ["uniform", "weighted"], "default": "uniform"}, + {"name": "normalize_type", "type": "string", "choices": ["tree", "forest"], "default": "tree"}, + {"name": "rate_drop", "type": "float", "range": [0, 1], "default": 0.0}, + {"name": "one_drop", "type": "bool", "default": 0}, + {"name": "skip_drop", "type": "float", "range": [0, 1], "default": 0.0}, + + # Linear Booster parameters + {"name": "updater", "type": "string", "choices": ["shotgun", "coord_descent"], "default": "shotgun"}, + {"name": "feature_selector", "type": "string", "choices": ["cyclic", "shuffle", "random", "greedy", "thrifty"], "default": "cyclic"}, + {"name": "top_k", "type": "int", "default": 0}, + + # Learning task parameters + {"name": "objective", "type": "string", "choices": [ + "reg:squarederror", "reg:squaredlogerror", "reg:logistic", "reg:pseudohubererror", + "reg:absoluteerror", "reg:quantileerror", "binary:logistic", "binary:logitraw", + "binary:hinge", "count:poisson", "survival:cox", "survival:aft", "multi:softmax", + "multi:softprob", "rank:ndcg", "rank:map", "rank:pairwise", "reg:gamma", "reg:tweedie" + ], "default": "reg:squarederror"}, + {"name": "base_score", "type": "float", "default": 0.5}, + {"name": "eval_metric", "type": "string", "default": "rmse"}, + + # Special objective parameters + {"name": "tweedie_variance_power", "type": "float", "range": [1, 2], "default": 1.5}, + {"name": "huber_slope", "type": "float", "default": 1.0}, + {"name": "quantile_alpha", "type": "float", "default": 0.5}, + {"name": "aft_loss_distribution", "type": "string", "choices": ["normal", "logistic", "extreme"], "default": "normal"}, + + # Rank parameters + {"name": "lambdarank_pair_method", "type": "string", "choices": ["mean", "topk"], "default": "mean"}, + {"name": "lambdarank_num_pair_per_sample", "type": "int", "default": 1}, + {"name": "lambdarank_unbiased", "type": "bool", "default": False}, + {"name": "lambdarank_bias_norm", "type": "float", "default": 2.0}, + {"name": "ndcg_exp_gain", "type": "bool", "default": True} +] + +# Translation mapping for xgboost metric names to implemented metrics +xgboost_metrics = { + 'auc': 'Auc', + 'logloss': 'LogLoss', + 'aucpr': 'Auprc', +} diff --git a/MED3pa/visualization/mdr_visualization.py b/MED3pa/visualization/mdr_visualization.py new file mode 100644 index 0000000..b6bfa6e --- /dev/null +++ b/MED3pa/visualization/mdr_visualization.py @@ -0,0 +1,64 @@ +""" +The mdr_visualization.py module manages visualization methods for the Metrics by Declaration Rates (MDR) curves. +""" + +import matplotlib.pyplot as plt +import os +from typing import List, Optional + +from MED3pa.med3pa.results import Med3paResults + + +def visualize_mdr(result: Med3paResults, filename: str = 'mdr', linewidth: int = 1, metrics: Optional[List[str]] = None, + dr: Optional[int] = None, save: bool = True, show: bool = True, save_format: str = 'svg') -> None: + """ + Visualizes the MDR curves, and saves the plot if save is True. + + Args: + result (Med3paResults): The results of the experiment to visualize. + filename (str): The name of the file to be saved. Defaults to 'mdr'. + linewidth (int): The width of the lines in the plot. Defaults to 1. + metrics (List[str], optional): List of metrics to add in the plot. Defaults to None, which means all available + metrics. + dr (int, optional): Declaration rate applied to predictions. If specified, adds a line on the plot to the + corresponding declaration rate. + save (bool): Whether to save the plot. Defaults to True. + show (bool): Whether to show the plot in the terminal. Defaults to True. + save_format (str): The format of the saved plot. Defaults to 'svg'. + + """ + mdr_values = result.test_record.metrics_by_dr + + declaration_rates = sorted(map(int, mdr_values.keys())) + + if metrics is None: + metrics = ['Accuracy', 'Precision', 'Recall', 'F1Score', 'Specificity', 'Sensitivity', 'Auc', 'NPV', 'PPV'] + + for metric in metrics: + values = [] + for dr in declaration_rates: + if metric in mdr_values[dr]['metrics']: + values.append(mdr_values[dr]['metrics'][metric]) + elif metric in ['Positive%', 'population_percentage', 'min_confidence_level', 'mean_confidence_level']: + values.append(mdr_values[dr][metric]) + else: + values.append(None) # Handle missing values + + plt.plot(declaration_rates, values, label=metric, linewidth=linewidth) + + # If the dr parameter is different from None or 100, add the vertical line + if dr is not None and dr != 100: + plt.axvline(x=dr, color='k', linestyle='--', linewidth=linewidth) + + # Plot parameters + plt.xlabel("Declaration Rate") + plt.ylabel("Metric Value") + plt.legend() + plt.title("Metrics vs Declaration Rate") + plt.grid(True, linestyle='--', alpha=0.7, linewidth=2) + if save: + # Create the directory if it doesn't exist + os.makedirs(filename, exist_ok=True) + plt.savefig(f"{filename}.{save_format}", format=save_format) + if show: + plt.show() diff --git a/MED3pa/visualization/profiles_visualization.py b/MED3pa/visualization/profiles_visualization.py new file mode 100644 index 0000000..1c76e62 --- /dev/null +++ b/MED3pa/visualization/profiles_visualization.py @@ -0,0 +1,167 @@ +""" +The profiles_visualization.py module manages visualization methods for the resulting profiles of the MED3pa method. +""" + +import os +import re +import shutil +from jinja2 import Environment, FileSystemLoader +from typing import List, Optional + +from MED3pa.med3pa.results import Med3paResults + + +_template_folder = "tree_template" +_save_template_folder = None + + +def visualize_tree(result: Med3paResults, filename: str = 'profiles', dr: int = 100, samp_ratio: int = 0, + data_set: str = 'test', metrics_list: List | None = None, profile_depth: int = 4, + open_results: bool = True, port: int = 8000) -> None: + """ + Vizualization method for the profiles obtained from the MED3pa method. + + Args: + result (Med3paResults): The results of the experiment to visualize. + filename (str): The name of the file to be saved. Defaults to 'profiles'. + dr (int): Declaration rate of predictions for the visualization. Defaults to 100, meaning all predictions. + samp_ratio (int): The minimum samples ratio in each profile (in percentage). Defaults to 0, meaning no minimal + samples per profile. + data_set (str): The name of the data set to visualize. Defaults to 'test', options are 'train', 'valid', 'test'. + metrics_list (List | None): The list of metrics to visualize. Defaults to None, meaning shown metrics are: + ['Specificity', 'Sensitivity', 'NPV', 'PPV', 'AUC'] + profile_depth (int): Maximum profile depth. Defaults to 4. + open_results (bool): Whether to open the results of the experiment in the web browser. Defaults to True. + port (int): The local port used to host the python server to open the web browser. Defaults to port 8000. + """ + # Set folder path for the template + global _save_template_folder, _template_folder + _save_template_folder = os.path.dirname(filename) + "/" + _template_folder + + assert data_set in ['reference', 'test'], "Invalid data_set, must be in ['reference', 'test']" + + if metrics_list is None: + metrics_list = ['Specificity', 'Sensitivity', 'NPV', 'PPV', 'AUC'] + + rendered_html = _generate_tree_html(result=result, samp_ratio=samp_ratio, dr=dr, data_set=data_set, + metrics_list=metrics_list, + max_depth=profile_depth) + + # Save the template with the HTML file + template_path = os.path.join(os.path.dirname(__file__), _template_folder) + shutil.copytree(template_path, _save_template_folder, dirs_exist_ok=True) + # Save the HTML file + with open(filename + '.html', 'w') as f: + f.write(rendered_html) + print(f"Tree visualization generated: '{filename}.html'") + + if open_results: # Open html results file in web browser + import http.server + import socketserver + import threading + import webbrowser + + def start_server(): + handler = http.server.SimpleHTTPRequestHandler + with socketserver.TCPServer(("", port), handler) as httpd: + print(f"Serving at http://localhost:{port}") + httpd.serve_forever() + + # Start server in background + thread = threading.Thread(target=start_server, daemon=True) + thread.start() + + # Open HTML file in default browser + url = f"http://localhost:{port}/{filename}.html" + webbrowser.open(url) + + # Keep the script alive + input("Press Enter to stop the server...\n") + print("\n\nTo open the HTML file in the web browser:\n" + "1. Activate your conda environment: conda start ;\n" + f"2. Start a local Python server: python -m http.server {port};\n" + f"3. Open the following URL in your browser: http://localhost:{port}/{filename}.html") + + +def _generate_tree_html(result: Med3paResults, samp_ratio: int, dr: int, data_set: str, + metrics_list: Optional[List] = None, max_depth: Optional[int] = None) -> str: + """ + Generates the tree visualization HTML. + + Args: + result (Med3paResults): The results of the experiment to visualize. + samp_ratio (int): The minimum samples ratio in each profile (in percentage). + dr (int): Declaration rate of predictions for the visualization. + data_set (str): The name of the data set to visualize. Defaults to 'test', options are 'train', 'valid', 'test'. + metrics_list (List | None): The list of metrics to visualize. Defaults to None, meaning no metrics are shown. + max_depth (int): Maximum profile depth. + + Returns: + str: The HTML string of the visualization tree. + """ + global _save_template_folder, _template_folder + + # Get the absolute path of the current file's directory + template_path = os.path.join(os.path.dirname(__file__), _template_folder) + env = Environment(loader=FileSystemLoader(template_path)) + template = env.get_template('tree.html') + + # Read profiles for the specified data_set + profiles_to_visualize = _read_tree_section(result=result, samp_ratio=samp_ratio, dr=dr, data_set=data_set) + + # Show only metrics in metrics_list, if specified + if metrics_list is not None: + for profile in profiles_to_visualize: + all_metrics = profile.metrics + profile.metrics = {key: all_metrics.get(key) for key in metrics_list if key in all_metrics.keys()} + + if max_depth is not None: + profiles_to_visualize = [profile for profile in profiles_to_visualize if len(profile.path) <= max_depth] + + # Convert Profile objects to dicts before json serialization + profiles_to_visualize = [vars(profile) for profile in profiles_to_visualize] + + # Render the HTML with the list of nodes and base path + rendered_html = template.render( + nodes=profiles_to_visualize, + base_path=_save_template_folder + ) + return rendered_html + + +def _read_tree_section(result: Med3paResults, samp_ratio: int, dr: int, data_set: str) -> List: + """ + Generates the tree visualization HTML. + + Args: + result (Med3paResults): The results of the experiment to visualize. + dr (int): Declaration rate of predictions for the visualization. + samp_ratio (int): The minimum samples ratio in each profile (in percentage). + data_set (str): The name of the data set to visualize. Defaults to 'test', options are 'train', 'valid', 'test'. + + Returns: + List: The profiles to visualize in the HTML file. + """ + if data_set == 'reference': + assert result.reference_record is not None, ("MED3pa experiment not applied to reference data, no profiles " + "available.") + profiles_to_visualize = result.reference_record.profiles_manager.profiles_records[samp_ratio][dr] + else: + profiles_to_visualize = result.test_record.profiles_manager.profiles_records[samp_ratio][dr] + + # Round condition values + for profile in profiles_to_visualize: + profile.path = [re.sub(r'(?" +}; + +// Color Interpolation Function +function interpolateColor(color1, color2, factor) { + const hex = (color) => parseInt(color.slice(1), 16); + const r = (color) => (color >> 16) & 255; + const g = (color) => (color >> 8) & 255; + const b = (color) => color & 255; + + const c1 = hex(color1); + const c2 = hex(color2); + const rVal = Math.round(r(c1) + factor * (r(c2) - r(c1))); + const gVal = Math.round(g(c1) + factor * (g(c2) - g(c1))); + const bVal = Math.round(b(c1) + factor * (b(c2) - b(c1))); + + return `rgb(${rVal}, ${gVal}, ${bVal})`; +} + +// Determine color range dynamically based on parameter +function getColorForValue(value, max, min) { + const range = max - min; + const normalizedValue = Math.min(Math.max((value - min) / range, 0), 1); + let color; + + if (normalizedValue < 0.5) { + color = interpolateColor("#c90404", "#e07502", normalizedValue * 2); // Red to Orange + } else { + color = interpolateColor("#e07502", "#068a0c", (normalizedValue - 0.5) * 2); // Orange to Green + } + + return color; +} + +function toggleLegendVisibility(show) { + const legendContainer = document.querySelector(".legend-container"); + legendContainer.style.display = show ? "flex" : "none"; +} + + +// Apply color to nodes based on parameter and custom min/max range +function applyColorToNodes() { + const infoType = document.querySelector("input[name='info-type']:checked").value; + const parameter = document.getElementById("color-parameter").value; + if (parameter === "") return; // Exit if no parameter is selected + + const min = parseFloat(document.getElementById("color-min").value); + const max = parseFloat(document.getElementById("color-max").value); + + treeData.forEach((node) => { + const nodeElement = document.getElementById(`node-${node.id}`); + if (node[infoType] && node[infoType][parameter] != null && nodeElement) { + const value = node[infoType][parameter]; + const color = getColorForValue(value, max, min); + + const titleElement = nodeElement.querySelector('.node-title'); + titleElement.style.backgroundColor = color; + } + }); + toggleLegendVisibility(true); +} + +// Reset Node Colors +function resetNodeColors() { + treeData.forEach((node) => { + const nodeElement = document.getElementById(`node-${node.id}`); + if (nodeElement) { + const titleElement = nodeElement.querySelector('.node-title'); + titleElement.style.backgroundColor = ""; // Reset color + } + }); + toggleLegendVisibility(false) +} + +// Function to create node content +function createNodeContent(node, infoType, isPhantom = false) { + const container = document.createElement("div"); + container.className = "node-container"; + + const title = document.createElement("div"); + title.className = isPhantom ? "node-title lost-profile-title" : "node-title"; + title.innerText = isPhantom ? "Lost Profile" : `Profile ${node.id}`; + container.appendChild(title); + + const content = document.createElement("div"); + content.className = "node-content"; + + const info = node[infoType]; + + if (info && !isPhantom) { + if (infoType === "detectron_results") { + // Handle detectron_results specifically + if (info["Tests Results"]) { + content.innerHTML = info["Tests Results"] + .map((result) => { + // Extract relevant information + const strategy = `Strategy: ${result.Strategy}`; + const shiftProbOrPValue = + result.shift_probability !== undefined + ? `Shift Probability: ${result.shift_probability.toFixed(2)}` + : result.p_value !== undefined + ? `P-Value: ${result.p_value.toFixed(2)}` + : "No data available"; + return `${strategy}
${shiftProbOrPValue}`; + }) + .join("

"); // Add spacing between strategies + + } else { + content.innerText = "Detectron was not executed!"; + } + } else { + // Default case for other info types + content.innerHTML = Object.entries(info) + .map(([key, value]) => { + const formattedValue = + typeof value === "number" && !Number.isInteger(value) + ? value.toFixed(2) + : value; + const color = node.text_color && key in node.text_color ? node.text_color[key] : "black"; + if (node.highlight) { + content.style.fontSize = "20px"; + } + return `

${key}: ${formattedValue !== null ? formattedValue : "N/A"}

`; + }) + .join(""); + + // content.innerHTML = Object.entries(info) + // .map(([key, value]) => { + // const formattedValue = + // typeof value === "number" && !Number.isInteger(value) + // ? value.toFixed(3) + // : value; + // return `${key}: ${formattedValue !== null ? formattedValue : "N/A"}`; + // }) + // .join("
"); + } + } else { + content.innerText = "No data available"; + } + + container.appendChild(content); + return container; +} + + +// Build the tree +// Build the tree and re-render nodes based on selected info type +function buildTree(data, rootElement, parentPath = ["*"], infoType = "node information") { + const rootNode = data.find(node => JSON.stringify(node.path) === JSON.stringify(parentPath)); + if (rootNode) { + const li = document.createElement("li"); + + // const conditionLabel = document.createElement("div"); + // conditionLabel.className = "condition-label"; + // conditionLabel.innerText = '*'; + // li.appendChild(conditionLabel); + + const nodeContent = createNodeContent(rootNode, infoType); + nodeContent.id = `node-${rootNode.id}`; + li.appendChild(nodeContent); + + const rootUl = document.createElement("ul"); + li.appendChild(rootUl); + rootElement.appendChild(li); + + buildChildren(data, rootUl, rootNode.path, infoType); + } +} + +// Recursively build children nodes based on selected info type +function buildChildren(data, parentElement, parentPath, infoType) { + const children = data.filter(node => JSON.stringify(node.path.slice(0, -1)) === JSON.stringify(parentPath)); + + children.forEach((node) => { + const li = document.createElement("li"); + + const conditionLabel = document.createElement("div"); + conditionLabel.className = "condition-label"; + conditionLabel.innerText = node.path[node.path.length - 1]; + li.appendChild(conditionLabel); + + const nodeContent = createNodeContent(node, infoType); + nodeContent.id = `node-${node.id}`; + li.appendChild(nodeContent); + + parentElement.appendChild(li); + + const hasChildren = data.some(n => JSON.stringify(n.path.slice(0, -1)) === JSON.stringify(node.path)); + if (hasChildren) { + const ul = document.createElement("ul"); + buildChildren(data, ul, node.path, infoType); + li.appendChild(ul); + } + }); + + if (children.length === 1) { + const phantomLi = document.createElement("li"); + + const existingCondition = children[0].path[children[0].path.length - 1]; + const oppositeCondition = existingCondition.includes("<=") + ? existingCondition.replace("<=", ">") + : existingCondition.replace(">", "<="); + + const phantomConditionLabel = document.createElement("div"); + phantomConditionLabel.className = "condition-label"; + phantomConditionLabel.innerText = oppositeCondition; + phantomLi.appendChild(phantomConditionLabel); + + const phantomNodeContent = createNodeContent({}, infoType, true); + phantomLi.appendChild(phantomNodeContent); + + parentElement.appendChild(phantomLi); + } +} + +// Update tree display based on selected information type +function updateTreeDisplay() { + const infoType = document.querySelector("input[name='info-type']:checked").value; + const treeRoot = document.getElementById("tree-root"); + treeRoot.innerHTML = ""; // Clear existing tree + buildTree(treeData, treeRoot, ["*"], infoType); // Rebuild tree with selected info type +} + +// Update options based on selected info type +function updateColorParameterOptions(infoType) { + const colorParameterSelect = document.getElementById("color-parameter"); + colorParameterSelect.innerHTML = ""; + + const sampleNode = treeData.find(node => node[infoType]); + if (sampleNode && sampleNode[infoType]) { + Object.keys(sampleNode[infoType]).forEach(key => { + const option = document.createElement("option"); + option.value = key; + option.textContent = key; + colorParameterSelect.appendChild(option); + }); + } + + document.getElementById("color-min").value = infoType === "metrics" ? 0 : 0; + document.getElementById("color-max").value = infoType === "metrics" ? 1 : 100; +} + +// Disable checkboxes based on available data +function updateCheckboxAvailability() { + // console.log(document.getElementById("general-info-checkbox")); + document.getElementById("general-info-checkbox").disabled = !treeData.some(node => node["node information"]); + document.getElementById("performance-info-checkbox").disabled = !treeData.some(node => node.metrics); + document.getElementById("shift-detection-checkbox").disabled = !treeData.some(node => node.detectron_results); +} + +// Initialize the Tree +function initializeTree() { + const treeRoot = document.getElementById("tree-root"); + buildTree(treeData, treeRoot); + // console.log(treeData) + updateCheckboxAvailability(); + updateColorParameterOptions("node information"); +} + +const treeContainer = document.getElementById("tree-root"); +panzoom(treeContainer); + +// function downloadTreeAsPNG() { +// const treeContainer = document.getElementById("tree-container"); +// +// htmlToImage.toPng(treeContainer) +// .then((dataUrl) => { +// const link = document.createElement("a"); +// link.href = dataUrl; +// link.download = "tree.png"; +// link.click(); +// }) +// .catch((error) => { +// console.error("Error generating PNG with html-to-image: ", error); +// }); +// } + +function downloadTreeAsPDF() { + const treeContainer = document.getElementById("tree-container"); + + // Ensure domtoimage is loaded + if (typeof window.domtoimage === "undefined") { + console.error("dom-to-image-more is not loaded!"); + return; + } + + window.domtoimage.toPng(treeContainer) // Convert to PNG for embedding in PDF + .then((dataUrl) => { + const pdf = new window.jspdf.jsPDF(); // Create a new PDF document + const imgWidth = 190; // Width in mm + const imgHeight = (treeContainer.clientHeight / treeContainer.clientWidth) * imgWidth; // Maintain aspect ratio + + pdf.addImage(dataUrl, "PNG", 10, 10, imgWidth, imgHeight); + pdf.save("tree.pdf"); // Download as PDF + }) + .catch((error) => { + console.error("Error generating PDF with dom-to-image-more: ", error); + }); +} + +// function downloadTreeAsSVG() { +// const treeContainer = document.getElementById("tree-container"); +// +// // Ensure domtoimage is loaded +// if (typeof window.domtoimage === "undefined") { +// console.error("dom-to-image-more is not loaded!"); +// return; +// } +// +// window.domtoimage.toSvg(treeContainer) +// .then((dataUrl) => { +// const link = document.createElement("a"); +// link.href = dataUrl; +// link.download = "tree.svg"; +// link.click(); +// }) +// .catch((error) => { +// console.error("Error generating SVG with dom-to-image-more: ", error); +// }); +// } + +function downloadTreeAsSVG() { + const treeContainer = document.getElementById("tree-container"); + + if (!treeContainer) { + console.error("Tree container not found!"); + return; + } + + const rect = treeContainer.getBoundingClientRect(); + const width = Math.max(1, rect.width); + const height = Math.max(1, rect.height); + + // Create the SVG element + const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg"); + svg.setAttribute("width", width); + svg.setAttribute("height", height); + svg.setAttribute("viewBox", `0 0 ${width} ${height}`); + svg.setAttribute("xmlns", "http://www.w3.org/2000/svg"); + + // Background + const bgRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + bgRect.setAttribute("width", "100%"); + bgRect.setAttribute("height", "100%"); + bgRect.setAttribute("fill", "white"); + svg.appendChild(bgRect); + + const nodes = treeContainer.querySelectorAll(".node-container"); + let positions = {}; + let bg_colors = {}; + + // Zoom factor + let zoomFactor = parseFloat(document.getElementById("tree-root").style.transform.match(/matrix\(([^)]+)\)/)[1].split(',')[0]); + + // Define round rectangles + const cornerRadius = 5 * zoomFactor; // Scale based on zoom + + const clipPath = ` + + `; + + // Create an SVG filter for the shadow around rectangles + const svgFilter = document.createElementNS("http://www.w3.org/2000/svg", "filter"); + svgFilter.setAttribute("id", "shadow"); + svgFilter.setAttribute("x", "-20%"); + svgFilter.setAttribute("y", "-20%"); + svgFilter.setAttribute("width", "140%"); + svgFilter.setAttribute("height", "140%"); + + // Create a Gaussian blur effect for the shadow + const feGaussianBlur = document.createElementNS("http://www.w3.org/2000/svg", "feGaussianBlur"); + feGaussianBlur.setAttribute("in", "SourceAlpha"); + feGaussianBlur.setAttribute("stdDeviation", 3 * zoomFactor); // Adjust blur size + svgFilter.appendChild(feGaussianBlur); + + // Merge original shape with shadow + const feMerge = document.createElementNS("http://www.w3.org/2000/svg", "feMerge"); + const feMergeNode1 = document.createElementNS("http://www.w3.org/2000/svg", "feMergeNode"); + // feMergeNode1.setAttribute("in", "offsetBlur"); + feMerge.appendChild(feMergeNode1); + const feMergeNode2 = document.createElementNS("http://www.w3.org/2000/svg", "feMergeNode"); + feMergeNode2.setAttribute("in", "SourceGraphic"); + feMerge.appendChild(feMergeNode2); + + svgFilter.appendChild(feMerge); + + // Append filter to SVG defs + const defs = document.createElementNS("http://www.w3.org/2000/svg", "defs"); + defs.appendChild(svgFilter); + svg.appendChild(defs); + + // First pass: Draw nodes, titles, and content + nodes.forEach((node, index) => { + let nodeId; + if (node.id === ""){ + const parentNode = node.closest("li").parentElement.closest("li")?.querySelector(".node-container"); + nodeId = 'child_' + parentNode.id; + } + else { + nodeId = node.id; + } + const nodeRect = node.getBoundingClientRect(); + const x = nodeRect.left - rect.left; + const y = nodeRect.top - rect.top; + const width = nodeRect.width; + const height = nodeRect.height; + positions[nodeId] = { x, y, width, height }; + + // Get background color from .node-title + const titleElement = node.querySelector(".node-title"); + let backgroundColor = "white"; // Default + if (titleElement) { + const computedStyle = window.getComputedStyle(titleElement); + backgroundColor = computedStyle.backgroundColor || "white"; + } + bg_colors[nodeId] = backgroundColor + + // Draw node rectangle with white background color + const svgRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + svgRect.setAttribute("x", x); + svgRect.setAttribute("y", y); + svgRect.setAttribute("rx", 5 * zoomFactor); + svgRect.setAttribute("ry", 5 * zoomFactor); + svgRect.setAttribute("width", width); + svgRect.setAttribute("height", height); + svgRect.setAttribute("fill", "white"); + svgRect.setAttribute("stroke", "black"); + svgRect.setAttribute("stroke-width", "0.25") + svgRect.setAttribute("filter", "url(#shadow)"); // Apply the shadow filter + svgRect.setAttribute("style", "opacity: 0.5;") + svg.appendChild(svgRect); + + // **If it's the first node, add the condition rectangle ("*")** + if (index === 0) { + // const conditionHeight = 20 * zoomFactor; // Height of the condition box + // const conditionRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + // conditionRect.setAttribute("x", x); + // conditionRect.setAttribute("y", y); + // conditionRect.setAttribute("rx", 5 * zoomFactor); + // conditionRect.setAttribute("ry", 5 * zoomFactor); + // conditionRect.setAttribute("width", width); + // conditionRect.setAttribute("height", conditionHeight); + // conditionRect.setAttribute("fill", bg_colors[nodeId]); // Match background color + // conditionRect.setAttribute("stroke", "black"); + // svg.appendChild(conditionRect); + // Define a path with rounded top corners only + const clipRect = document.createElementNS("http://www.w3.org/2000/svg", "path"); + const r = 5 * zoomFactor; // Radius for top corners + const w = width; + const h = 20 * zoomFactor; // Height of the condition box + const conditionPath = document.createElementNS("http://www.w3.org/2000/svg", "path"); + conditionPath.setAttribute("d", `M${x + r},${y} + H${x + w - r} + A${r},${r} 0 0 1 ${x + w},${y + r} + V${y + h} + H${x} + V${y + r} + A${r},${r} 0 0 1 ${x + r},${y} + Z`); + conditionPath.setAttribute("fill", bg_colors[nodeId]); + conditionPath.setAttribute("stroke", "black"); + conditionPath.setAttribute("stroke-width", "0.25") + svg.appendChild(conditionPath); + + // // Add "*" text inside the condition rectangle + // const conditionText = document.createElementNS("http://www.w3.org/2000/svg", "text"); + // conditionText.setAttribute("x", x + width / 2); + // conditionText.setAttribute("y", y + conditionHeight / 2 + 5*zoomFactor); + // conditionText.setAttribute("font-size", 14 * zoomFactor); + // conditionText.setAttribute("fill", bg_colors[node.id] === "rgb(49, 49, 49)" ? "white" : "black"); + // conditionText.setAttribute("text-anchor", "middle"); + // conditionText.setAttribute("dominant-baseline", "middle"); + // conditionText.textContent = "*"; + // svg.appendChild(conditionText); + } + + // Get node content and handle

elements + const contentElement = node.querySelector(".node-content"); + if (contentElement) { + let paragraphs = contentElement.querySelectorAll("p"); // Select all

elements + let fontsize = parseFloat(window.getComputedStyle(contentElement).fontSize) || 11; + let lineHeight = (fontsize * 1.33) * zoomFactor; + let startY = y + lineHeight + 20 * zoomFactor;//+ height / 2 - (paragraphs.length) * (lineHeight / 2) + 20 * zoomFactor; + + paragraphs.forEach((p, i) => { + const contentText = document.createElementNS("http://www.w3.org/2000/svg", "text"); + contentText.setAttribute("x", x + 10 * zoomFactor); + contentText.setAttribute("y", startY + i * lineHeight); + // console.log(contentElement.style.fontSize); + contentText.setAttribute("font-size", fontsize * zoomFactor); + + // Extract color from

style + let textColor = p.style.color || "black"; + if (textColor === "black"){ + contentText.setAttribute("fill", "black") + } + else { + contentText.setAttribute("fill", "white"); + } + contentText.setAttribute("text-anchor", "left"); + contentText.textContent = p.textContent.trim(); + + svg.appendChild(contentText); + + if (textColor !== "black"){ + // Append text to SVG temporarily to measure size + const textBBox = contentText.getBBox(); // Get actual width & height + + // Use getComputedTextLength() to get text width + let textWidth = p.clientWidth;// contentText.getComputedTextLength(); + let textHeight = fontsize * zoomFactor * 1.2; // Approximate height + + // Create a black rectangle behind the text + const backgroundRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + backgroundRect.setAttribute("x", x + 5 * zoomFactor); // Slightly offset + backgroundRect.setAttribute("y", startY + i * lineHeight - textHeight * 0.8); // Align with text + backgroundRect.setAttribute("width", (textWidth + 10) * zoomFactor); // Add padding + backgroundRect.setAttribute("height", textHeight); // Adjust height slightly + backgroundRect.setAttribute("fill", textColor); + backgroundRect.setAttribute("rx", 4 * zoomFactor); // Rounded corners (optional) + + // Move rect before text to layer it behind + svg.insertBefore(backgroundRect, contentText); + + + // // Add rectangle behind text + // // Measure text width + // let textWidth = p.textContent.length * fontsize * 0.5 * zoomFactor; // Approximation + // let textHeight = fontsize * zoomFactor * 1.2; // Slightly larger than font-size + // // Create a black rectangle behind the text + // const backgroundRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + // backgroundRect.setAttribute("x", x + 5 * zoomFactor); // Slightly offset + // backgroundRect.setAttribute("y", startY + i * lineHeight - textHeight * 0.8); // Align with text + // backgroundRect.setAttribute("width", textWidth); + // backgroundRect.setAttribute("height", textHeight); + // backgroundRect.setAttribute("fill", textColor); + // backgroundRect.setAttribute("rx", 4 * zoomFactor); // Rounded corners (optional) + // // Append rectangle before text (so it's behind) + // svg.appendChild(backgroundRect); + } + + }); + } + + // // Get node content and handle
+ // const contentElement = node.querySelector(".node-content"); + // if (contentElement) { + // let contentLines = contentElement.innerHTML.split(//i); + // let lineHeight = 14 * zoomFactor; + // let startY = y + height / 2 - (contentLines.length ) * (lineHeight / 2) + 20 * zoomFactor; + // + // contentLines.forEach((line, i) => { + // const contentText = document.createElementNS("http://www.w3.org/2000/svg", "text"); + // contentText.setAttribute("x", x + 10 * zoomFactor); + // contentText.setAttribute("y", startY + i * lineHeight); + // contentText.setAttribute("font-size", 12 * zoomFactor); + // contentText.setAttribute("fill", "black"); + // contentText.setAttribute("text-anchor", "left"); + // contentText.textContent = line.trim(); + // svg.appendChild(contentText); + // }); + // } + }); + + // Second pass: Draw connections and condition labels + const conditions = treeContainer.querySelectorAll(".condition-label"); + conditions.forEach((conditionLabel) => { + const parentNode = conditionLabel.closest("li").parentElement.closest("li")?.querySelector(".node-container"); + const childNode = conditionLabel.closest("li")?.querySelector(".node-container"); + + if (!parentNode || !childNode) return; + + const parentID = parentNode.id; + let childID; + if (childNode.id === ""){ + childID = 'child_' + parentNode.id; + } + else { + childID = childNode.id; + } + if (positions[parentID] && positions[childID]) { + const { x: px, y: py, width: pw, height: ph } = positions[parentID]; + const { x: cx, y: cy, width: cw, height: ch } = positions[childID]; + + const x1 = px + pw / 2; + const y1 = py + ph; + const x2 = cx + cw / 2; + const y2 = cy; + + const midX = (x1 + x2) / 2; // Control point in the middle + const midY = (y1 + y2) / 2; // Adjust curve height (negative = curve upwards) + const controlX1 = x1; // Control points for first downward curve + const controlY1 = y1 + 50; + + const controlX2 = midX; // Control points for final downward curve + const controlY2 = y2 - 50; + + const path = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path.setAttribute("d", `M ${x1},${y1} + C ${x1},${y2} ${x2},${y1} ${x2},${y2}`); + path.setAttribute("stroke", "black"); + path.setAttribute("fill", "none"); + path.setAttribute("stroke-width", 2 * zoomFactor); + svg.appendChild(path) + + // // Draw line from parent to child + // const line = document.createElementNS("http://www.w3.org/2000/svg", "line"); + // line.setAttribute("x1", px + pw / 2); + // line.setAttribute("y1", py + ph); + // line.setAttribute("x2", cx + cw / 2); + // line.setAttribute("y2", cy); + // line.setAttribute("stroke", "black"); + // line.setAttribute("stroke-width", "1"); + // svg.appendChild(line); + + // Add condition label inside a small rectangle + let conditionText = conditionLabel.innerText.trim(); + // Loop through the dictionary and replace occurrences + for (const [key, value] of Object.entries(replacements)) { + conditionText = conditionText.replace(new RegExp(key, "g"), value); + } + // conditionText = conditionText.replace('service_group_', '').replace('admission_group_','admission_') + console.log(conditionText) + const conditionRectWidth = cw; + const conditionRectHeight = 20 * zoomFactor; + const condX = cx; + const condY = cy ; // - conditionRectHeight + + // // Small rectangle for the condition label + // const conditionRect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); + // conditionRect.setAttribute("x", condX); + // conditionRect.setAttribute("y", condY); + // conditionRect.setAttribute("rx", 5 * zoomFactor); + // conditionRect.setAttribute("ry", 5 * zoomFactor); + // conditionRect.setAttribute("width", conditionRectWidth); + // conditionRect.setAttribute("height", conditionRectHeight); + // conditionRect.setAttribute("fill", bg_colors[childID]); + // conditionRect.setAttribute("stroke", "black"); + // svg.appendChild(conditionRect); + + // Define a path with rounded top corners only + const clipRect = document.createElementNS("http://www.w3.org/2000/svg", "path"); + const r = 5 * zoomFactor; // Radius for top corners + const w = conditionRectWidth; + const h = conditionRectHeight; + const conditionPath = document.createElementNS("http://www.w3.org/2000/svg", "path"); + conditionPath.setAttribute("d", `M${condX + r},${condY} + H${condX + w - r} + A${r},${r} 0 0 1 ${condX + w},${condY + r} + V${condY + h} + H${condX} + V${condY + r} + A${r},${r} 0 0 1 ${condX + r},${condY} + Z`); + conditionPath.setAttribute("fill", bg_colors[childID]); + conditionPath.setAttribute("stroke", "black"); + conditionPath.setAttribute("stroke-width", "0.25") + svg.appendChild(conditionPath); + + // Condition label text inside the small rectangle + const conditionTextElement = document.createElementNS("http://www.w3.org/2000/svg", "text"); + conditionTextElement.setAttribute("x", condX + conditionRectWidth / 2); + conditionTextElement.setAttribute("y", condY + conditionRectHeight / 2 + 5*zoomFactor); + conditionTextElement.setAttribute("font-size", 14 * zoomFactor); + if (bg_colors[childID] === "rgb(49, 49, 49)"){ + conditionTextElement.setAttribute("fill", "white"); + } + else { + conditionTextElement.setAttribute("fill", "black"); + } + conditionTextElement.setAttribute("text-anchor", "middle"); + conditionTextElement.textContent = conditionText; + console.log(conditionTextElement) + svg.appendChild(conditionTextElement); + } + }); + + // Serialize and download + const serializer = new XMLSerializer(); + let svgString = serializer.serializeToString(svg); + // // Ensure correct encoding for `<` and `>` + // svgString = svgString.replace(/</g, "\\textless").replace(/>/g, "\\textgreater"); + const blob = new Blob([svgString], { type: "image/svg+xml" }); + const link = document.createElement("a"); + link.href = URL.createObjectURL(blob); + link.download = "tree.svg"; + link.click(); +} + + + + + + + + + +function updateLegendLabels() { + const minRangeInput = document.getElementById("color-min"); + const maxRangeInput = document.getElementById("color-max"); + + // Convert values to numbers + const minRange = parseFloat(minRangeInput.value); + const maxRange = parseFloat(maxRangeInput.value); + + // console.log(typeof minRange) + document.getElementById("min-legend").textContent = minRange; + document.getElementById("quarter-legend").textContent = ((maxRange - minRange) * 0.25 + minRange).toFixed(3); + document.getElementById("half-legend").textContent = ((maxRange - minRange) * 0.5 + minRange).toFixed(3); + document.getElementById("three-quarters-legend").textContent = ((maxRange - minRange) * 0.75 + minRange).toFixed(3); + document.getElementById("max-legend").textContent = maxRange; +} + +function toggleColorSection() { + const colorOptions = document.getElementById("color-options"); + const colorToggle = document.getElementById("color-toggle"); + + if (colorToggle.checked) { + colorOptions.style.display = "block"; + } else { + colorOptions.style.display = "none"; + } +} + +// Event listeners to update legend when min or max range changes +document.getElementById("color-min").addEventListener("input", updateLegendLabels); +document.getElementById("color-max").addEventListener("input", updateLegendLabels); + +// Initial call to set legend values on page load +updateLegendLabels(); + + +// Event listeners +document.getElementById("color-nodes-button").addEventListener("click", applyColorToNodes); +document.getElementById("reset-color-button").addEventListener("click", resetNodeColors); +document.getElementById("download-png-button").addEventListener("click", downloadTreeAsSVG); +document.getElementById("download-pdf-button").addEventListener("click", downloadTreeAsPDF); +document.querySelectorAll("input[name='info-type']").forEach((radio) => { + radio.addEventListener("change", (event) => { + updateColorParameterOptions(event.target.value); + updateLegendLabels(); + updateTreeDisplay() + }); +}); + +function updateColorToggleAvailability() { + const infoType = document.querySelector("input[name='info-type']:checked").value; + const colorToggle = document.getElementById("color-toggle"); + const colorToggleContainer = document.querySelector(".color-toggle"); + + if (infoType === "detectron_results") { + // Disable the toggle and hide the color options + colorToggle.checked = false; + colorToggle.disabled = true; + colorToggleContainer.classList.add("disabled-toggle"); + document.getElementById("color-options").style.display = "none"; + } else { + // Enable the toggle + colorToggle.disabled = false; + colorToggleContainer.classList.remove("disabled-toggle"); + } +} + +// Attach this function to the event listener for radio buttons +document.querySelectorAll("input[name='info-type']").forEach((radio) => { + radio.addEventListener("change", updateColorToggleAvailability); +}); + +// Initial call to ensure toggle is correctly enabled/disabled on page load +updateColorToggleAvailability(); + +document.addEventListener("DOMContentLoaded", function() { + initializeTree(); +}); + +// document.addEventListener("DOMContentLoaded", function () { +// const performanceRadio = document.getElementById("performance-info-checkbox"); +// const metricsFilter = document.getElementById("metrics-filter"); +// const metricsSelect = document.getElementById("metrics-select"); +// +// const availableMetrics = ["Accuracy", "Auc", "Auprc", "BalancedAccuracy", "F1Score", "LogLoss", +// "MCC", "NPV", "PPV", "Precision", "Recall", "Sensitivity", "Specificity"]; // Example metrics +// +// // Populate dropdown +// availableMetrics.forEach(metric => { +// const option = document.createElement("option"); +// option.value = metric; +// option.textContent = metric; +// metricsSelect.appendChild(option); +// }); +// +// // Show/hide filter when "Node Performance" is selected +// document.querySelectorAll("input[name='info-type']").forEach(radio => { +// radio.addEventListener("change", function () { +// metricsFilter.style.display = performanceRadio.checked ? "block" : "none"; +// }); +// }); +// }); +// document.addEventListener("click", function () { +// document.addEventListener("click", function () { +// const performanceRadio = document.getElementById("performance-info-checkbox"); +// const metricsFilter = document.getElementById("metrics-filter"); +// const metricsSelect = document.getElementById("metrics-select"); +// +// // List of available metrics (update if needed) +// const availableMetrics = ["Accuracy", "Auc", "Auprc", "BalancedAccuracy", "F1Score", "LogLoss", "MCC", "NPV", "PPV", "Precision", "Recall", "Sensitivity", "Specificity"]; +// +// // Populate dropdown with metric options +// availableMetrics.forEach(metric => { +// const option = document.createElement("option"); +// option.value = metric; +// option.textContent = metric; +// option.selected = true; // Select all by default +// metricsSelect.appendChild(option); +// }); +// +// // Store original `.node-content` values for each node **before any filtering happens** +// document.querySelectorAll(".node-container").forEach(node => { +// const contentDiv = node.querySelector(".node-content"); +// if (contentDiv) { +// contentDiv.dataset.originalContent = contentDiv.innerHTML; // Save original full content +// } +// }); +// +// // Show/hide filter when "Node Performance" is selected +// document.querySelectorAll("input[name='info-type']").forEach(radio => { +// radio.addEventListener("change", function () { +// if (performanceRadio.checked) { +// metricsFilter.style.display = "block"; +// applyMetricFilter(); // Apply filter immediately when switching +// } else { +// metricsFilter.style.display = "none"; +// restoreAllMetrics(); // Restore all metrics when switching away +// } +// }); +// }); +// +// // Apply filtering when metric selection changes +// metricsSelect.addEventListener("change", applyMetricFilter); +// +// // function applyMetricFilter() { +// // const selectedMetrics = Array.from(metricsSelect.selectedOptions).map(option => option.value); +// // +// // document.querySelectorAll(".node-container").forEach(node => { +// // const contentDiv = node.querySelector(".node-content"); +// // if (contentDiv) { +// // // **Retrieve original content before filtering** +// // const originalContent = contentDiv.dataset.originalContent; +// // if (!originalContent) return; +// // +// // // **Reset content first to prevent accumulating removals** +// // let filteredMetrics = originalContent.split("
").filter(metricLine => { +// // return selectedMetrics.some(selectedMetric => +// // metricLine.trim().toLowerCase().startsWith(selectedMetric.toLowerCase() + ":") +// // ); +// // }); +// // +// // // **Ensure at least one metric remains to prevent empty display** +// // contentDiv.innerHTML = filteredMetrics.length > 0 ? filteredMetrics.join("
") : "(No metrics selected)"; +// // } +// // }); +// // } +// +// function applyMetricFilter() { +// const selectedMetrics = Array.from(metricsSelect.selectedOptions).map(option => option.value.toLowerCase()); +// +// document.querySelectorAll(".node-container").forEach(node => { +// node.querySelectorAll(".metric").forEach(metricElement => { +// const metricName = metricElement.dataset.metric.toLowerCase(); +// metricElement.style.display = selectedMetrics.includes(metricName) ? "block" : "none"; +// }); +// }); +// } +// +// +// function restoreAllMetrics() { +// document.querySelectorAll(".node-container").forEach(node => { +// const contentDiv = node.querySelector(".node-content"); +// if (contentDiv && contentDiv.dataset.originalContent) { +// contentDiv.innerHTML = contentDiv.dataset.originalContent; // Restore full original content +// } +// }); +// } +// }); diff --git a/MED3pa/visualization/tree_template/style.css b/MED3pa/visualization/tree_template/style.css new file mode 100644 index 0000000..6e80baa --- /dev/null +++ b/MED3pa/visualization/tree_template/style.css @@ -0,0 +1,378 @@ +/* Global Reset */ +* { + margin: 0; + padding: 0; +} + +/* Body and Container Styles */ +body { + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + background-color: #f4f4f4; +} + +.main-container { + display: flex; + width: 100%; + max-width: 1200px; +} + +.canvas-container { + width: 60%; + height: 80vh; + background-color: #fff; + border: 1px solid #ccc; + border-radius: 8px; + overflow: hidden; + position: relative; + box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.1); +} + +.tree-container { + cursor: grab; +} + +.controls { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + padding-left: 20px; +} + +/* Tree Structure Styles */ +.tree ul { + padding-top: 20px; + position: relative; + display: flex; + justify-content: center; + gap: 10px; + transition: all 0.5s; +} + +.tree li { + list-style-type: none; + text-align: center; + position: relative; + padding: 20px 5px 0 5px; + display: inline-block; + vertical-align: top; + transition: all 0.5s; +} + +/* Node container styles */ +.node-container { + border: 1px solid #ccc; + background-color: #f9f9f9; + border-radius: 5px; + box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1); + overflow: hidden; + display: inline-block; + font-family: Arial, sans-serif; + white-space: nowrap; + min-width: 180px; +} + +.node-title { + background-color: #313131; + color: white; + font-weight: bold; + padding: 5px; + text-align: center; + font-size: 12px; +} + +.node-content { + padding: 10px; + font-size: 11px; + color: #333; + text-align: left; +} + +.condition-label { + font-size: 10px; + color: #555; + background: #fff; + padding: 2px 4px; + border-radius: 4px; + white-space: nowrap; + margin-bottom: 5px; + font-family: Arial, sans-serif; + font-weight: 600; +} + +/* Connector Lines */ +.tree li::before, +.tree li::after { + content: ''; + position: absolute; + top: 0; + width: 50%; + height: 20px; + border-top: 1px solid #ccc; +} + +.tree li::before { + right: 50%; +} + +.tree li::after { + left: 50%; + border-left: 1px solid #ccc; +} + +/* Adjust connectors for specific cases */ +.tree li:only-child::before, +.tree li:only-child::after { + display: none; +} + +.tree li:first-child::before, +.tree li:last-child::after { + border: 0 none; +} + +.tree li:last-child::before { + border-right: 1px solid #ccc; + border-radius: 0 5px 0 0; +} + +.tree li:first-child::after { + border-radius: 5px 0 0 0; +} + +.tree ul ul::before { + content: ''; + position: absolute; + top: 0; + left: 50%; + border-left: 1px solid #ccc; + width: 0; + height: 20px; +} + +/* Special styling for Lost Profile nodes */ +.lost-profile-title { + background-color: #ddddde; + color: white; +} + +/* Node Parameters Container */ +.node-parameters-container { + width: 30%; + padding: 10px; + background-color: #ffffff; + border: 1px solid #e0e0e0; + border-radius: 8px; + font-family: Arial, sans-serif; + box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.1); +} + +.node-parameters-header { + display: flex; + align-items: center; + font-weight: bold; + font-size: 1.2em; + margin-bottom: 12px; +} + +.node-parameters-header .icon { + margin-right: 8px; +} + +.focus-section, +.color-section { + margin-bottom: 16px; +} + +.focus-section label, +.color-section label { + font-weight: bold; + display: block; + margin-bottom: 4px; +} + +.focus-options label, +.color-options label { + font-weight: normal; + display: block; + margin: 4px 0; +} + +.focus-options label input[type="radio"] { + margin-right: 8px; +} + +.color-toggle { + display: flex; + align-items: center; + gap: 5px; +} + +.color-toggle input[type="checkbox"] { + display: none; +} + +.color-toggle label { + width: 40px; + height: 20px; + background-color: #e0e0e0; + border-radius: 10px; + position: relative; + cursor: pointer; +} + +.color-toggle label:before { + content: ''; + position: absolute; + top: 2px; + left: 2px; + width: 16px; + height: 16px; + background-color: #007bff; + border-radius: 50%; + transition: all 0.3s; +} + +.color-toggle input:checked + label:before { + transform: translateX(20px); +} + +/* Color Options */ +.color-options { + display: none; + margin-top: 8px; + grid-template-columns: 1fr 1fr; + gap: 8px; +} + +.color-options select, +.color-options input { + padding: 4px; + border: 1px solid #cccccc; + border-radius: 4px; +} + +#color-parameter { + width: 100%; + margin: 0px; +} + +.min-max-section { + display: flex; + justify-content: space-between; + gap: 8px; + margin-top: 8px; +} + +.min-max-section div { + flex: 1; +} + +.min-max-section input { + width: 95%; +} + +/* Buttons */ +button { + width: 100%; + padding: 8px; + border: none; + border-radius: 4px; + font-weight: bold; + cursor: pointer; + transition: background-color 0.3s; + margin-top: 8px; +} + +button#color-nodes-button { + background-color: #007bff; + color: white; +} + +button#color-nodes-button:hover { + background-color: #0056b3; +} + +button#reset-color-button { + background-color: #6c757d; + color: white; +} + +button#reset-color-button:hover { + background-color: #5a6268; +} + +button#download-png-button { + background-color: #28a745; + color: white; + margin-top: 8px; +} + +button#download-png-button:hover { + background-color: #218838; +} + +button#download-pdf-button { + background-color: #28a745; + color: white; + margin-top: 8px; +} + +button#download-pdf-button:hover { + background-color: #218838; +} + +.color-buttons { + display: flex; + justify-content: space-between; + gap: 8px; + margin-top: 8px; + width: 100%; +} + +button { + flex: 1; + padding: 8px; +} + +/* Legend Styling */ +.legend-container { + position: absolute; + bottom: 20px; + right: 20px; + display: none; + align-items: center; + gap: 10px; +} + +.legend-bar { + width: 20px; + height: 200px; + background: linear-gradient(to bottom, #068a0c, #e07502, #c90404); + border: 1px solid #000; +} + +.legend-labels { + display: flex; + flex-direction: column; + justify-content: space-between; + height: 200px; +} + +.legend-labels div { + font-size: 14px; + text-align: right; +} + +/* Style for the toggle when disabled */ +.disabled-toggle { + opacity: 0.2; + pointer-events: none; /* Prevent interaction */ +} + +.disabled-toggle label { + cursor: not-allowed; +} \ No newline at end of file diff --git a/MED3pa/visualization/tree_template/tree.html b/MED3pa/visualization/tree_template/tree.html new file mode 100644 index 0000000..459768c --- /dev/null +++ b/MED3pa/visualization/tree_template/tree.html @@ -0,0 +1,89 @@ + + + + + + Tree Visualization + + + +

+
+
    +
    +
    +
    +
    100
    +
    75
    +
    50
    +
    25
    +
    0
    +
    +
    +
    + +
    +
    + Node Parameters +
    +
    + +
    + + + + + + + + + +
    +
    +
    + +
    + + +
    +
    + + +
    +
    + + +
    +
    + + +
    +
    +
    + + +
    +
    +
    + + + +
    + + + +
    + + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..706e8b8 --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# MED3pa Package + +## Table of Contents +- [Overview](#overview) +- [Key Functionalities](#key-functionalities) +- [Subpackages](#subpackages) +- [Getting Started with the Package](#getting-started) + - [Installation](#installation) + - [A Simple Example](#a-simple-example) +- [Acknowledgement](#acknowledgement) +- [References](#references) +- [Authors](#authors) +- [Statement](#statement) +- [Supported Python Versions](#supported-python-versions) + +## Overview + +Overview + +The **MED3pa** package is specifically designed to address critical challenges in deploying machine learning models, particularly focusing on the robustness and reliability of models under real-world conditions. It provides comprehensive tools for evaluating model stability and performance in the face of **covariate shifts**, **uncertainty**, and **problematic data profiles**. + +## Key Functionalities + +- ** Model Confidence Estimation**: Through the MED3pa subpackage, the package measures the predictive confidence at both individual and group (profile) levels. This helps in understanding the reliability of model predictions and in making informed decisions based on model outputs. + +- **Identification of Problematic Profiles**: MED3pa analyzes data profiles for whom the BaseModel consistently leads to poor model performance. This capability allows developers to refine training datasets or retrain models to handle these edge cases effectively. + +## Subpackages + +

    + Overview +

    + +The package is structured into four distinct subpackages: + +- **datasets**: Stores and manages the dataset. +- **models**: Handles ML models operations. +- **med3pa**: Evaluates the model’s performance & extracts problematic profiles. + +This modularity allows users to easily integrate and utilize specific functionalities tailored to their needs without dealing with unnecessary complexities. + +## Getting Started with the Package + +To get started with MED3pa, follow the installation instructions and usage examples provided in the documentation. + +### Installation + +```bash +pip install MED3pa +``` + +### A simple exemple +We have created a [simple example](https://github.com/MEDomics-UdeS/MED3pa/tree/main/examples) of using the MED3pa package. +[See the full example here](https://github.com/MEDomics-UdeS/MED3pa/tree/main/examples/oym_example.ipynb) +```python +from MED3pa.datasets import DatasetsManager +from MED3pa.med3pa import Med3paExperiment +from MED3pa.models import BaseModelManager +from MED3pa.visualization.mdr_visualization import visualize_mdr +from MED3pa.visualization.profiles_visualization import visualize_tree + +... + +# Initialize the DatasetsManager +datasets = DatasetsManager() +datasets.set_from_data(dataset_type="testing", + observations=x_evaluation.to_numpy(), + true_labels=y_evaluation, + column_labels=x_evaluation.columns) +# Initialize the BaseModelManager +base_model_manager = BaseModelManager(model=clf) + +# Execute the MED3PA experiment +results = Med3paExperiment.run( + datasets_manager=datasets, + base_model_manager=base_model_manager, + **med3pa_params +) + +# Save the results to a specified directory +results.save(file_path='results/oym') + +# Visualize results +visualize_mdr(result=results, filename='results/oym/mdr') +visualize_tree(result=results, filename='results/oym/profiles') + +``` + +## Acknowledgement +MED3pa is an open-source package developed at the [MEDomicsLab](https://www.medomicslab.com/) laboratory. We welcome any contribution and feedback. + +## Authors +* [Olivier Lefebvre: ](https://www.linkedin.com/in/olivier-lefebvre-bb8837162/) Student (Ph. D. Computer science) at Université de Sherbrooke +* [Lyna Chikouche: ](https://www.linkedin.com/in/lynahiba-chikouche-62a5181bb/) Research intern at MEDomics-Udes laboratory. +* [Ludmila Amriou: ](https://www.linkedin.com/in/ludmila-amriou-875b58238//) Research intern at MEDomics-Udes laboratory. +* [Martin Vallières: ](https://www.linkedin.com/in/martvallieres/) Associate professor, Department of Oncology at McGill University + +## Statement + +This package is part of https://www.medomics.ai/, a package providing research utility tools for developing precision medicine applications. + +``` +Copyright (C) 2024 MEDomics consortium + +GPLV3 LICENSE SYNOPSIS + +Here's what the license entails: + +1. Anyone can copy, modify and distribute this software. +2. You have to include the license and copyright notice with each and every distribution. +3. You can use this software privately. +4. You can use this software for commercial purposes. +5. If you dare build your business solely from this code, you risk open-sourcing the whole code base. +6. If you modify it, you have to indicate changes made to the code. +7. Any modifications of this code base MUST be distributed with the same license, GPLv3. +8. This software is provided without warranty. +9. The software author or license can not be held liable for any damages inflicted by the software. +``` + +More information about the [LICENSE can be found here](https://github.com/MEDomics-UdeS/MEDimage/blob/main/LICENSE.md) + +## Supported Python Versions + +The **MED3pa** package is developed and tested with Python 3.12.3. + +Additionally, it is compatible with the following Python versions: +- Python 3.11.x +- Python 3.10.x +- Python 3.9.x + +While the package may work with other versions of Python, these are the versions we officially support and recommend. \ No newline at end of file diff --git a/docs/MED3pa.datasets.rst b/docs/MED3pa.datasets.rst new file mode 100644 index 0000000..62aa042 --- /dev/null +++ b/docs/MED3pa.datasets.rst @@ -0,0 +1,59 @@ +datasets subpackage +======================= +Overview +--------- +The ``datasets`` subpackage is a core component designed to manage the complexities of data handling and preparation. +It provides a structured and extensible way to load, process, and manipulate datasets from various sources, +ensuring they are ready for use in machine learning models or data analysis tasks. The subpackage is built to accommodate a variety of data formats and includes functionalities for masking, +sampling, and managing multiple datasets, making it versatile for different phases of data-driven projects. + +This subpackage is composed of the following modules: + +- **loading_context**: Manages the strategy for loading data, allowing flexibility in the source file format. +- **loading_strategies**: Implements specific strategies for different file formats. +- **manager**: Coordinates access and manipulations across multiple datasets used in ML pipelines. +- **masked**: Allows multiple operation on the datasets, like sampling, refining, cloning...etc. + + +The package includes the following classes: + +.. image:: ./diagrams/datasets.svg + :alt: UML class diagram of the subpackage. + :align: center + +.. raw:: html + +
    + +**loading\_context module** +--------------------------- + +.. automodule:: MED3pa.datasets.loading_context + :members: + :undoc-members: + :show-inheritance: + +**loading\_strategies module** +------------------------------ + +.. automodule:: MED3pa.datasets.loading_strategies + :members: + :undoc-members: + :show-inheritance: + +**manager module** +------------------------------ + +.. automodule:: MED3pa.datasets.manager + :members: + :undoc-members: + :show-inheritance: + +**masked module** +------------------------------ + +.. automodule:: MED3pa.datasets.masked + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/MED3pa.med3pa.rst b/docs/MED3pa.med3pa.rst new file mode 100644 index 0000000..7e312c0 --- /dev/null +++ b/docs/MED3pa.med3pa.rst @@ -0,0 +1,101 @@ +med3pa subpackage +===================== + +Overview +--------- + +The ``med3pa`` subpackage represents a modularized and enhanced version of the original `"MED3PA code" `__, +designed to augment clinical decision-making by providing a robust framework for evaluating and managing **model uncertainty** in healthcare applications. +It introduces a sophisticated approach to assessing model performance that transcends **traditional global metrics**, +focusing instead on the concept of predictive confidence at both individual and aggregated levels. + +Key Components: +~~~~~~~~~~~~~~~~ +- **Individualized Predictive Confidence (IPC)**: This component employs ``regression`` models to estimate the predictive confidence for individual data points. The IPC model, adaptable with various regression algorithms like ``Random Forest``, is particularly aimed at quantifying the uncertainty associated with each prediction, allowing for a detailed analysis of model reliability. + +- **Aggregated Predictive Confidence (APC)**: Contrasting with IPC, APC focuses on groups of similar data points, using a ``decision tree regressor`` to analyze uncertainty in aggregated profiles. This method helps identify patterns or groups where the model's performance might be suboptimal, facilitating targeted improvements. + +- **Mixed Predictive Confidence (MPC)**: This model combines results from IPC and APC to derive a composite measure of confidence. MPC values are then used to further scrutinize the model's performance and to identify problematic profiles where the base model might fail. + +Advanced Analysis with MDR: +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The med3pa approach innovatively employs metrics by declaration rate (MDR), which evaluates model metrics at various confidence thresholds. +This methodology not only highlights how the model performs across different confidence levels but also **pinpoints specific profiles** (groups of data points) that might be **problematic**. +By doing so, it aids in the proactive identification and mitigation of potential model failures. + +Extensibility for Integration: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Our package is meticulously designed to facilitate and accommodate **integration** with other methods. +This extensibility allows ``med3pa`` to assess shifts in data distributions, **especially focusing on the problematic profiles identified through MPC**. +Such integration enhances the comprehensive assessment of how external changes or shifts might affect model reliability over time, +ensuring that the model remains robust and accurate in dynamic clinical environments. + +In essence, ``med3pa`` is dedicated to advancing safer clinical deployments by providing tools that not only predict outcomes but also critically analyze and **improve the understanding of where and why predictions might fail.** +This helps ensure that deployed models are not just effective but are also reliable and trustworthy in real-world settings. + +this subpackage includes the following classes: + +.. image:: ./diagrams/med3pa.svg + :alt: UML class diagram of the subpackage. + :align: center + +.. raw:: html + +
    + +uncertainty module +------------------- + +.. automodule:: MED3pa.med3pa.uncertainty + :members: + :undoc-members: + :show-inheritance: + +models module +--------------------------- + +.. automodule:: MED3pa.med3pa.models + :members: + :undoc-members: + :show-inheritance: + +tree module +------------------------- + +.. automodule:: MED3pa.med3pa.tree + :members: + :undoc-members: + :show-inheritance: + +Profiles module +----------------------------- + +.. automodule:: MED3pa.med3pa.profiles + :members: + :undoc-members: + :show-inheritance: + +MDR module +------------------------------- + +.. automodule:: MED3pa.med3pa.mdr + :members: + :undoc-members: + :show-inheritance: + +experiment module +------------------------------- + +.. automodule:: MED3pa.med3pa.experiment + :members: + :undoc-members: + :show-inheritance: + +compraison module +------------------------------- + +.. automodule:: MED3pa.med3pa.comparaison + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/MED3pa.models.rst b/docs/MED3pa.models.rst new file mode 100644 index 0000000..59d22bd --- /dev/null +++ b/docs/MED3pa.models.rst @@ -0,0 +1,113 @@ +models subpackage +===================== + +Overview +---------- +The ``models`` subpackage serves as a comprehensive backbone for managing and utilizing machine learning models across various aspects of the package. It is meticulously designed to support the development, evaluation, and optimization of models, +ensuring compatibility and efficiency in integrating with ``med3pa`` methodology. +This subpackage leverages several design patterns, such as **Factory, Singleton, and Prototype,** to ensure robustness, modularity, and scalability. +Through its structured approach, This subpackage offers a robust framework that includes abstract base classes for uniformity in model operations, concrete implementations for specialized algorithms, +and utility tools for precise model evaluation and data handling. + +This subpackage is composed of the following modules: + +- **factories.py**: Utilizes the factory design pattern to facilitate flexible and scalable model instantiation, enhancing the modularity of model creation. + +- **abstract_models.py**: Defines the abstract base classes for all model types, including general models, classification models, and regression models. These classes provide a common interface for model operations. + +- **concrete_classifiers.py**: Contains concrete implementations of classification models, like the XGBoostModel. + +- **concrete_regressors.py**: Provides implementations for regression models, such as RandomForestRegressor and DecisionTreeRegressor. + +- **abstract_metrics.py**: Provides the abstract base classes for all evaluation metrics, centralizing the logic for metric calculations across different model types. + +- **classification_metrics.py**: Implements a variety of evaluation metrics specifically for classification tasks, such as accuracy, precision, and recall. + +- **regression_metrics.py**: Hosts evaluation metrics for regression tasks, including mean squared error and R2 score, crucial for assessing model performance. + +- **data_strategies.py**: Offers various strategies for preparing data, ensuring compatibility and optimal formatting for model training and evaluation. + +- **base.py**: Manages a singleton base model, responsible for the instantiation and cloning of the base model across the methods, ensuring consistency and reliability in model management. + +The package includes the following classes: + +.. image:: ./diagrams/models.svg + :alt: UML class diagram of the subpackage. + :align: center + +.. raw:: html + +
    + +**factories module** +------------------------------ + +.. automodule:: MED3pa.models.factories + :members: + :undoc-members: + :show-inheritance: + +**abstract\_models module** +------------------------------------- + +.. automodule:: MED3pa.models.abstract_models + :members: + :undoc-members: + :show-inheritance: + +**concrete\_classifiers module** +------------------------------------------ + +.. automodule:: MED3pa.models.concrete_classifiers + :members: + :undoc-members: + :show-inheritance: + +**concrete\_regressors module** +----------------------------------------- + +.. automodule:: MED3pa.models.concrete_regressors + :members: + :undoc-members: + :show-inheritance: + +**abstract\_metrics module** +-------------------------------------- + +.. automodule:: MED3pa.models.abstract_metrics + :members: + :undoc-members: + :show-inheritance: + +**classification\_metrics module** +-------------------------------------------- + +.. automodule:: MED3pa.models.classification_metrics + :members: + :undoc-members: + :show-inheritance: + +**regression\_metrics module** +---------------------------------------- + +.. automodule:: MED3pa.models.regression_metrics + :members: + :undoc-members: + :show-inheritance: + +**data\_strategies module** +------------------------------------- + +.. automodule:: MED3pa.models.data_strategies + :members: + :undoc-members: + :show-inheritance: + +**base module** +------------------------- + +.. automodule:: MED3pa.models.base + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/MED3pa.rst b/docs/MED3pa.rst new file mode 100644 index 0000000..3df55ec --- /dev/null +++ b/docs/MED3pa.rst @@ -0,0 +1,8 @@ + +.. toctree:: + :maxdepth: 4 + + MED3pa.datasets + MED3pa.models + MED3pa.med3pa + diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 0000000..7c92c37 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,15 @@ +.wy-nav-content { + max-width: 100% !important; +} + +.github-link { + display: flex; + align-items: center; + text-decoration: none; + font-size: 16px; + margin-left: 15px; +} + +.github-link i { + margin-right: 8px; +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..ef99d1f --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,61 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +sys.path.insert(0, os.path.abspath('..')) + + +# -- Project information ----------------------------------------------------- + +project = 'Med3pa documentation' +copyright = '2025, MEDomics Consortium' +author = 'Olivier Lefebvre, Lyna Chikouche, Ludmila Amriou, Martin Vallières' + +# The full version, including alpha/beta/rc tags +release = '1.0.0' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# This is usually at the top of conf.py +html_static_path = ['_static'] + +# Append this to include your custom CSS +html_css_files = [ + 'custom.css', +] \ No newline at end of file diff --git a/docs/datasets_tutorials.rst b/docs/datasets_tutorials.rst new file mode 100644 index 0000000..073a96e --- /dev/null +++ b/docs/datasets_tutorials.rst @@ -0,0 +1,226 @@ +Working with datasets subpackage +-------------------------------- +The ``datasets`` subpackage is designed to provide robust and flexible data loading and management functionalities tailored for machine learning models. +This tutorial will guide you through using this subpackage to handle and prepare your data efficiently. + +Using the DatasetsManager class +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``DatasetsManager`` class in the ``MED3pa.datasets`` submodule is designed to facilitate the management of various datasets needed for model training and evaluation. This tutorial provides a step-by-step guide on setting up and using the ``DatasetsManager`` to handle data efficiently. + +Step 1: Importing the DatasetsManager +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, import the ``DatasetsManager`` from the ``MED3pa.datasets`` submodule: + +.. code-block:: python + + from MED3pa.datasets import DatasetsManager + +Step 2: Creating an Instance of DatasetsManager +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Create an instance of ``DatasetsManager``. This instance will manage all operations related to datasets: + +.. code-block:: python + + manager = DatasetsManager() + +Step 3: Loading Datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +With the ``DatasetsManager``, you can load various segments of your base model datasets, such as training, validation, reference, and testing datasets. You don't need to load all datasets at once. Provide the path to your dataset and the name of the target column: + +**Loading from File** + +.. code-block:: python + + manager.set_from_file(dataset_type="training", file='./path_to_training_dataset.csv', target_column_name='target_column') + +**Loading from NumPy Arrays** + +You can also load the datasets as NumPy arrays. For this, you need to specify the features, true labels, and column labels as a list (excluding the target column) if they are not already set. + +.. code-block:: python + + import numpy as np + import pandas as pd + + df = pd.read_csv('./path_to_validation_dataset.csv') + + # Extract labels and features + X_val = df.drop(columns='target_column').values + y_val = df['target_column'].values + + # Example of setting data from numpy arrays + manager.set_from_data(dataset_type="validation", observations=X_val, true_labels=y_val) + +Step 4: Ensuring Feature Consistency +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Upon loading the first dataset, the ``DatasetsManager`` automatically extracts and stores the names of features. You can retrieve the list of these features using: + +.. code-block:: python + + features = manager.get_column_labels() + +Ensure that the features of subsequent datasets (e.g., validation or testing) match those of the initially loaded dataset to avoid errors and maintain data consistency. + +Step 5: Retrieving Data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Retrieve the loaded data in different formats as needed. + +**As NumPy Arrays** + +.. code-block:: python + + observations, labels = manager.get_dataset_by_type(dataset_type="training") + +**As a MaskedDataset Instance** + +To work with the data encapsulated in a ``MaskedDataset`` instance, which might include more functionalities, retrieve it by setting ``return_instance`` to ``True``: + +.. code-block:: python + + training_dataset = manager.get_dataset_by_type(dataset_type="training", return_instance=True) + +Step 6: Getting a Summary +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +You can print a summary of the ``DatasetsManager`` to see the status of the datasets: + +.. code-block:: python + + manager.summarize() + +Step 7: Saving and Resetting Datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +You can save a specific dataset to a CSV file or reset all datasets managed by the ``DatasetsManager``. + +**Save to CSV** + +.. code-block:: python + + manager.save_dataset_to_csv(dataset_type="training", file_path='./path_to_save_training_dataset.csv') + +**Reset Datasets** + +.. code-block:: python + + manager.reset_datasets() + manager.summarize() # Verify that all datasets are reset + +Summary of Outputs +^^^^^^^^^^^^^^^^^^^ + +When you run the ``summary`` method, you should get an output similar to this, indicating the status and details of each dataset: + +.. code-block:: none + + training_set: {'num_samples': 151, 'num_features': 23, 'has_pseudo_labels': False, 'has_pseudo_probabilities': False, 'has_confidence_scores': False} + validation_set: {'num_samples': 1000, 'num_features': 10, 'has_pseudo_labels': False, 'has_pseudo_probabilities': False, 'has_confidence_scores': False} + reference_set: Not set + testing_set: Not set + column_labels: ['feature_1', 'feature_2', ..., 'feature_23'] + +Using the MaskedDataset Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ``MaskedDataset`` class, a crucial component of the ``MED3pa.datasets`` submodule, facilitates nuanced data operations that are essential for custom data manipulation and model training processes. This tutorial details common usage scenarios of the ``MaskedDataset``. + +Step 1: Importing Necessary Modules +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Begin by importing the ``MaskedDataset`` and ``DatasetsManager``, along with NumPy for additional data operations: + +.. code-block:: python + + from MED3pa.datasets import MaskedDataset, DatasetsManager + import numpy as np + +Step 2: Loading Data with DatasetsManager +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Retrieve the dataset as a ``MaskedDataset`` instance: + +.. code-block:: python + + manager = DatasetsManager() + manager.set_from_file(dataset_type="training", file='./path_to_training_dataset.csv', target_column_name='target_column') + training_dataset = manager.get_dataset_by_type(dataset_type="training", return_instance=True) + +Step 3: Performing Operations on MaskedDataset +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Once you have your dataset loaded as a ``MaskedDataset`` instance, you can perform various operations: + +**Cloning the Dataset** + +Create a copy of the dataset to ensure the original data remains unchanged during experimentation: + +.. code-block:: python + + cloned_instance = training_dataset.clone() + +**Sampling the Dataset** + +Randomly sample a subset of the dataset, useful for creating training or validation splits: + +.. code-block:: python + + sampled_instance = training_dataset.sample(N=20, seed=42) + +**Refining the Dataset** + +Refine the dataset based on a boolean mask, which is useful for filtering out unwanted data points: + +.. code-block:: python + + mask = np.random.rand(len(training_dataset)) > 0.5 + remaining_samples = training_dataset.refine(mask=mask) + +**Setting Pseudo Labels and Probabilities** + +Set pseudo labels and probabilities for the dataset, for this you only need to pass the pseudo_probabilities along with the threshold to extract the pseudo_labels from: + +.. code-block:: python + + pseudo_probs = np.random.rand(len(training_dataset)) + training_dataset.set_pseudo_probs_labels(pseudo_probabilities=pseudo_probs, threshold=0.5) + +**Getting Feature Vectors and Labels** + +Retrieve the feature vectors, true labels, and pseudo labels: + +.. code-block:: python + + observations = training_dataset.get_observations() + true_labels = training_dataset.get_true_labels() + pseudo_labels = training_dataset.get_pseudo_labels() + +**Getting Confidence Scores** + +Get the confidence scores if available: + +.. code-block:: python + + confidence_scores = training_dataset.get_confidence_scores() + +**Converting to DataFrame and Saving to CSV** + +You can save the dataset as a .csv file. Using `save_to_csv` and providing the path this will save the observations, true_labels, pseudo_labels and pseudo_probabilities, alongside confidence_scores if they were set: + +.. code-block:: python + + df = training_dataset.to_dataframe() + training_dataset.save_to_csv('./path_to_save_training_dataset.csv') + +**Getting Dataset Information** + +Get detailed information about the dataset, or you can directly use ``summary``: + +.. code-block:: python + + training_dataset.summarize() + +When you run the ``summarize`` method, you should get an output similar to this, indicating the status and details of the dataset: + +.. code-block:: none + + Number of samples: 151 + Number of features: 23 + Has pseudo labels: False + Has pseudo probabilities: False + Has confidence scores: False diff --git a/docs/diagrams/datasets.svg b/docs/diagrams/datasets.svg new file mode 100644 index 0000000..680f171 --- /dev/null +++ b/docs/diagrams/datasets.svg @@ -0,0 +1 @@ +datasets subpackageDataLoadingContextselected_strategy: DataLoadingStrategyset_strategy(strategy: DataLoadingStrategy): voidget_strategy(): DataLoadingStrategyload_as_np(file_path: str, target_column_name: str): Tuple[List[str], np.ndarray, np.ndarray]DataLoadingStrategyexecute(path_to_file: str, target_column_name: str): Tuple[List[str], np.ndarray, np.ndarray]CSVDataLoadingStrategyexecute(path_to_file: str, target_column_name: str): Tuple[List[str], np.ndarray, np.ndarray]DatasetsManagerbase_model_training_set: MaskedDatasetbase_model_validation_set: MaskedDatasetreference_set: MaskedDatasettesting_set: MaskedDatasetcolumn_labels: List[str]set_from_file(dataset_type: str, file: str, target_column_name: str): voidset_from_data(dataset_type: str, features: np.ndarray, true_labels: np.ndarray, column_labels: list = None): voidset_column_labels(columns: list): voidget_column_labels(): List[str]get_info(show_details: bool = True): dictsummarize(): voidreset_datasets(): voidget_dataset_by_type(dataset_type: str, return_instance: bool = False): MaskedDatasetsave_dataset_to_csv(dataset_type: str, file_path: str): void__get_base_model_training_data(return_instance: bool = False): Union[tuple, MaskedDataset]__get_base_model_validation_data(return_instance: bool = False): Union[tuple, MaskedDataset]__get_reference_data(return_instance: bool = False): Union[tuple, MaskedDataset]__get_testing_data(return_instance: bool = False): Union[tuple, MaskedDataset]MaskedDatasetobservations: np.ndarraytrue_labels: np.ndarraypseudo_labels: np.ndarraypseudo_probabilities: np.ndarrayconfidence_scores: np.ndarrayindices: np.ndarrayoriginal_indices: np.ndarraysample_counts: np.ndarraycolumn_labels: list__init__(observations: np.ndarray, true_labels: np.ndarray, column_labels: list = None)__getitem__(index: int): tuple__len__(): intrefine(mask: np.ndarray): intoriginal(): MaskedDatasetreset_indices(): voidsample_uniform(N: int, seed: int): MaskedDatasetsample_random(N: int, seed: int): MaskedDatasetget_observations(): np.ndarrayget_pseudo_labels(): np.ndarrayget_true_labels(): np.ndarrayget_pseudo_probabilities(): np.ndarrayget_confidence_scores(): np.ndarrayset_pseudo_probs_labels(pseudo_probabilities: np.ndarray, threshold=0.5): voidset_confidence_scores(confidence_scores: np.ndarray): voidset_pseudo_labels(pseudo_labels: np.ndarray): voidclone(): MaskedDatasetget_info(): dictsummarize(): voidto_dataframe(): pd.DataFramesave_to_csv(file_path: str): void11«uses» \ No newline at end of file diff --git a/docs/diagrams/detectron.svg b/docs/diagrams/detectron.svg new file mode 100644 index 0000000..9634266 --- /dev/null +++ b/docs/diagrams/detectron.svg @@ -0,0 +1 @@ +Detectron SubpackageDetectronEnsemblebase_model_manager: BaseModelManagerens_size: intcdcs: list[Model]evaluate_ensemble(datasets, n_runs, samples_size, training_params, set, patience, allow_margin, margin): DetectronRecordsManagerDetectronRecordseed: intmodel_id: intoriginal_count: intvalidation_auc: floattest_auc: floatpredicted_probabilities: np.ndarrayupdated_count: intrejected_samples: intupdate(validation_auc, test_auc, predicted_probabilities, count): voidto_dict(): dictDetectronRecordsManagerrecords: list[dict]sample_size: intidx: int__seed: intsampling_counts: intseed(seed: int): voidupdate(val_data_x, val_data_y, sample_size, model, model_id, predicted_probabilities, test_data_x, test_data_y): voidfreeze(): voidget_record(): pd.DataFramesave(path): voidload(path): DetectronRecordsManagercounts(max_ensemble_size): np.ndarrayrejection_rates(max_ensemble_size): np.ndarraypredicted_probabilities(max_ensemble_size): np.ndarrayrejected_counts(max_ensemble_size): np.ndarrayEarlyStopperpatience: intmin_delta: floatbest: floatwait: intmode: strupdate(metric: float): boolDetectronStrategyexecute(calibration_records, test_records): dictOriginalDisagreementStrategyexecute(calibration_records, test_records): dictMannWhitneyStrategyexecute(calibration_records, test_records): dictEnhancedDisagreementStrategyexecute(calibration_records, test_records): dictDetectronResultcal_record: DetectronRecordsManagertest_record: DetectronRecordsManagertest_results: dictcalibration_trajectories(): pd.DataFrametest_trajectories(): pd.DataFrameget_experiments_results(): dictanalyze_results(strategies): dictsave(file_path, file_name): voidDetectronExperimentrun(datasets, training_params, base_model_manager, samples_size, calib_result, ensemble_size, num_calibration_runs, patience, allow_margin, margin): tupleDetectronComparisonresults1_path: strresults2_path: strdetectron_results_comparaison: dictconfig_file: dict_check_experiment_name(): Nonecompare_detectron_results(): Nonecompare_config(): Nonecompare_experiments(): Nonesave(directory_path: str): None«uses»«uses»*«return»«return»«uses» \ No newline at end of file diff --git a/docs/diagrams/med3pa.svg b/docs/diagrams/med3pa.svg new file mode 100644 index 0000000..934aada --- /dev/null +++ b/docs/diagrams/med3pa.svg @@ -0,0 +1 @@ +Med3pa SubpackageUncertaintyMetriccalculate(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray): np.ndarrayAbsoluteErrorcalculate(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray): np.ndarrayUncertaintyCalculatormetric: UncertaintyMetriccalculate_uncertainty(x: np.ndarray, predicted_prob: np.ndarray, y_true: np.ndarray): np.ndarrayIPCModelmodel: Anyparams: dictgrid_search_params: dictoptimized: boolpretrained: boolmodel_name: strsupported_ipc_models(): listsupported_models_params(): Dict[str, Dict[str, Any]]optimize(param_grid: dict, cv: int, x: np.ndarray, error_prob: np.ndarray, sample_weight: np.ndarray = None)train(x: np.ndarray, error_prob: np.ndarray): Nonepredict(x: np.ndarray): np.ndarrayevaluate(X: np.ndarray, y: np.ndarray, eval_metrics: List[str], print_results: bool = False): Dict[str, float]get_info(): Dict[str, Any]save_model(file_path: str): Noneload_model(file_path: str): NoneAPCModelmodel: AnytreeRepresentation: TreeRepresentationfeatures: List[str]params: dictgrid_search_params: dictoptimized: boolpretrained: boolsupported_models_params(): Dict[str, Dict[str, Any]]train(x: np.ndarray, error_prob: np.ndarray): Noneoptimize(param_grid: dict, cv: int, x: np.ndarray, error_prob: np.ndarray, sample_weight: np.ndarray = None)predict(X: np.ndarray): np.ndarrayevaluate(X: np.ndarray, y: np.ndarray, eval_metrics: List[str], print_results: bool = False): Dict[str, float]get_info(): Dict[str, Any]save_model(file_path: str): Noneload_model(file_path: str): Noneload_tree(file_path: str): NoneMPCModelIPC_values: np.ndarrayAPC_values: np.ndarraypredict(): np.ndarrayTreeRepresentationfeatures: listhead: _TreeNodenb_nodes: intbuild_tree(dtr: DecisionTreeRegressorModel, X: DataFrame, y: Series, node_id: int, path: list): _TreeNodeget_all_nodes(): listsave_tree(file_path: str): None_TreeNodec_left: _TreeNodec_right: _TreeNodevalue: floatvalue_max: floatsamples_ratio: floatthreshold: floatfeature: strfeature_id: intnode_id: intpath: listassign_node(X: Union[DataFrame, Series]): floatget_all_nodes(): listto_dict(): dictMed3paRecordmetrics_by_dr: Dict[int, Dict]models_evaluation: Dict[str, Dict]profiles_manager: ProfilesManagerdatasets: Dict[int, MaskedDataset]experiment_config: Dict[str, Any]tree: Dict[str, Any]set_metrics_by_dr(metrics_by_dr: Dict): Noneset_profiles_manager(profile_manager: ProfilesManager): Noneset_models_evaluation(ipc_evaluation: Dict, apc_evaluation: Dict = None): Noneset_tree(tree: TreeRepresentation): Noneset_dataset(mode: str, dataset: MaskedDataset): Nonesave(file_path: str): NoneMed3paResultsreference_record: Med3paRecordtest_record: Med3paRecordexperiment_config: Dict[str, Any]detectron_results: DetectronResultipc_model: IPCModelapc_model: APCModelset_detectron_results(detectron_results: DetectronResult = None): Noneset_experiment_config(config: Dict[str, Any]): Noneset_models(ipc_model: IPCModel, apc_model: APCModel = None): Nonesave(file_path: str): Nonesave_models(file_path: str, mode: str = 'all'): NoneMed3paExperimentrun()_run_by_set()Med3paDetectronExperimentrun()Med3paComparisonresults1_path: strresults2_path: strcompare_profiles_metrics(): voidcompare_profiles_detectron_results(): voidcompare_global_metrics(): voidcompare_models_evaluation(): voidcompare_config(): voidcompare_experiments(): voidsave(directory_path: str): voidProfilenode_id: intpath: list[str]mean_value: floatmetrics: dictnode_information: dictdetectron_results: dictto_dict(save_all: bool = True): dictupdate_detectron_results(detectron_results: dict): voidupdate_metrics_results(metrics: dict): voidupdate_node_information(info: dict): voidProfilesManagerprofiles_records: dictlost_profiles_records: dictfeatures: list[str]insert_profiles(dr: int, min_samples_ratio: int, profiles: list[Profile]): voidinsert_lost_profiles(dr: int, min_samples_ratio: int, profiles: list[Profile]): voidget_profiles(min_samples_ratio: int = None, dr: int = None): dictget_lost_profiles(min_samples_ratio: int = None, dr: int = None): dicttransform_to_profiles(profiles_list: list[dict], to_dict: bool = False): list[Union[dict, Profile]]MDRCalculator_get_min_confidence_score(): float_calculate_metrics(): dict_list_difference_by_key(): List[Profile]_filter_by_profile(): tuplecalc_metrics_by_dr(): dictcalc_profiles_deprecated(): dictcalc_profiles(): dictcalc_metrics_by_profiles(): voiddetectron_by_profiles():dict«uses»«uses»**«uses»«uses»«returns»«uses»«uses»«uses»«uses» \ No newline at end of file diff --git a/docs/diagrams/models.svg b/docs/diagrams/models.svg new file mode 100644 index 0000000..ecad627 --- /dev/null +++ b/docs/diagrams/models.svg @@ -0,0 +1 @@ +models subpackageEvaluationMetricget_metric(metric_name: str = '') : functionsupported_metrics() : List[str]ClassificationEvaluationMetricsaccuracy(y_true, y_pred, sample_weight) :floatrecall(y_true, y_pred, sample_weight) : floatroc_auc(y_true, y_pred, sample_weight) : floataverage_precision(y_true, y_pred, sample_weight) : floatmatthews_corrcoef(y_true, y_pred, sample_weight) : floatprecision(y_true, y_pred, sample_weight) : floatf1_score(y_true, y_pred, sample_weight) : floatsensitivity(y_true, y_pred, sample_weight) : floatspecificity(y_true, y_pred, sample_weight) : floatppv(y_true, y_pred, sample_weight) : floatnpv(y_true, y_pred, sample_weight) : floatbalanced_accuracy(y_true, y_pred, sample_weight) : floatlog_loss(y_true, y_pred, sample_weight) : floatRegressionEvaluationMetricsmean_squared_error(y_true, y_pred, sample_weight) : floatroot_mean_squared_error(y_true, y_pred, sample_weight) : floatmean_absolute_error(y_true, y_pred, sample_weight) : floatr2_score(y_true, y_pred, sample_weight) : floatModelmodel: anymodel_class: typeparams: dictdata_preparation_strategy: DataPreparingStrategypickled_model: booleanevaluate(X, y, eval_metrics, print_results) : dictget_model() : anyget_model_type() : typeget_data_strategy() : DataPreparingStrategyget_params() : dictget_info() : dictset_model(model) : voidset_params(params) : voidupdate_params(params) : voidset_data_strategy(strategy) : voidsave(path) : voidClassificationModeltrain(x_train, y_train, x_validation, y_validation, training_parameters, balance_train_classes) : voidpredict(X, return_proba, threshold) : np.ndarraytrain_to_disagree(x_train, y_train, x_validation, y_validation, x_test, y_test, training_parameters, balance_train_classes, N) : voidRegressionModeltrain(x_train, y_train, x_validation, y_validation, training_parameters) : voidpredict(X) : np.ndarrayDataPreparingStrategyexecute(features, labels, weights) : objectToDmatrixStrategyToNumpyStrategyToDataframesStrategyXGBoostModelRandomForestRegressorModelDecisionTreeRegressorModelBaseModelManagerbaseModel : Modelset_base_model(model : Model) : voidget_instance() : Modelclone_base_model() : ModelModelFactoryget_factory(model_type) : ModelFactorycreate_model_with_hyperparams(model_type, hyperparams) : Modelcreate_model_from_pickled(pickled_file_path) : ModelXGBoostFactorycreate_model_with_hyperparams(hyperparams) : XGBoostModelcreate_model_from_pickled(loaded_model) : XGBoostModel1«uses»«uses»«returns» \ No newline at end of file diff --git a/docs/diagrams/package.svg b/docs/diagrams/package.svg new file mode 100644 index 0000000..1062902 --- /dev/null +++ b/docs/diagrams/package.svg @@ -0,0 +1,69 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/diagrams/package_white_bg.svg b/docs/diagrams/package_white_bg.svg new file mode 100644 index 0000000..5087ab6 --- /dev/null +++ b/docs/diagrams/package_white_bg.svg @@ -0,0 +1,408 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Identification of low-confidence Profiles + diff --git a/docs/diagrams/subpackages.svg b/docs/diagrams/subpackages.svg new file mode 100644 index 0000000..5a3059c --- /dev/null +++ b/docs/diagrams/subpackages.svg @@ -0,0 +1,189 @@ + + + + + + + + + + MED3pa + + + + datasets + + + models + + + med3pa + + + + + + + + «uses» + + + + + + «uses» + + + + + + diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..b789ef2 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,84 @@ +.. Med3pa documentation documentation master file, created by + sphinx-quickstart on Sun Jun 9 10:04:07 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to MED3pa documentation! +================================= + +Overview +======== +.. image:: ./diagrams/package_white_bg.svg + :alt: UML package diagram + :align: center + :width: 100% + +.. raw:: html + +
    + +Use and Application of the package +----------------------------------- +The ``MED3pa`` package is specifically designed to address critical challenges in deploying machine learning models, particularly focusing on the robustness and reliability of models under real-world conditions. +It provides comprehensive tools for evaluating model stability and performance in the face of **covariate shifts, and problematic data profiles.** + +Key functionalities +------------------- + +- **Uncertainty and Confidence Estimation:** Through the med3pa subpackage, the package measures the uncertainty and predictive confidence at both individual and group levels. This helps in understanding the reliability of model predictions and in making informed decisions based on model outputs. + +- **Identification of Problematic Profiles**: MED3pa analyzes data profiles that consistently lead to poor model performance. This capability allows developers to refine training datasets or retrain models to handle these edge cases effectively. + +Software Engineering Principles +------------------------------- +The ``MED3pa`` package is constructed with a strong emphasis on **software engineering principles**, making it a robust and scalable solution for machine learning model evaluation: + +- **Modular Design**: The package is structured into three distinct **subpackages** (med3pa, models, and datasets), each focusing on different aspects of model training and evaluation. This modularity allows users to easily integrate and utilize specific functionalities tailored to their needs without dealing with unnecessary complexities. +- **Extensibility**: Thanks to its modular architecture, the package can be effortlessly extended to include more functionalities or adapt to different use cases. New models, methods, or data handling procedures can be added with minimal impact on the existing system structure. +- **Use of Design Patterns**: MED3pa employs various design patterns that enhance its maintainability and usability. For example, the use of factory patterns in model creation and strategy patterns in handling different file extensions ensures that the system remains flexible and adaptable to new requirements. + +.. toctree:: + :maxdepth: 4 + :caption: Installation Guide + + installation + +Subpackages +============ +The package is structured into three distinct subpackages : ``datasets``, ``models`` and ``med3pa``, each focusing on different aspects of the package goals. +This modularity allows users to easily integrate and utilize specific functionalities tailored to their needs without dealing with unnecessary complexities. + +.. image:: ./diagrams/subpackages.svg + :alt: UML package diagram + :align: center + +.. raw:: html + +
    + +.. toctree:: + :maxdepth: 4 + :caption: Subpackages + + MED3pa + +Tutorials +============ +Welcome to the tutorials section of the med3pa documentation. Here, we offer comprehensive, step-by-step guides to help you effectively utilize the various subpackages within ``med3pa``. +Each tutorial is designed to enhance your understanding of the package's capabilities and to provide practical examples of how to integrate these tools into your data science and machine learning workflows. + +Explore each tutorial to learn how to make the most of MED3pa's robust functionalities for your projects. + +.. toctree:: + :maxdepth: 4 + :caption: Tutorials + + tutorials + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..bed9412 --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,31 @@ +Installation Guide +========================================== + +Welcome to the installation guide of ``MED3pa`` package. Follow the steps below to install and get started with the package. + +Prerequisites +------------- + +Before installing the package, ensure you have the following prerequisites: + +- Python 3.9 or later +- pip (Python package installer) + +Installation +------------ + +Step 1: Install via pip +~~~~~~~~~~~~~~~~~~~~~~~ + +You can install ``MED3pa`` directly from PyPI using pip:: + + pip install med3pa + +Step 2: Verify the Installation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To verify that the installation was successful, you can run the following command:: + + python -c "import MED3pa; print('Installation successful!')" + +If you see `Installation successful!`, then you are ready to start using ``MED3pa``. diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..954237b --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/med3pa_tutorials.rst b/docs/med3pa_tutorials.rst new file mode 100644 index 0000000..81b6c2f --- /dev/null +++ b/docs/med3pa_tutorials.rst @@ -0,0 +1,172 @@ +Working with the med3pa Subpackage +---------------------------------- +This tutorial guides you through the process of setting up and running comprehensive experiments using the ``med3pa`` subpackage. It includes steps to execute MED3pa experiment with ``Med3paExperiment``. + +Running the MED3pa Experiment +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Step 1: Setting up the Datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +First, configure the `DatasetsManager`. In the case of MED3pa only experiment you only need to set the DatasetManager with either `testing` and `reference` dataset: + +.. code-block:: python + + from MED3pa.datasets import DatasetsManager + + # Initialize the DatasetsManager + datasets = DatasetsManager() + + # Load datasets for reference, and testing + datasets.set_from_file(dataset_type="reference", file='./path_to_reference_data.csv', target_column_name='Outcome') + datasets.set_from_file(dataset_type="testing", file='./path_to_test_data.6.csv', target_column_name='Outcome') + + # Initialize the DatasetsManager + datasets2 = DatasetsManager() + + # Load datasets for reference, and testing + datasets2.set_from_file(dataset_type="reference", file='./data/test_data.csv', target_column_name='Outcome') + datasets2.set_from_file(dataset_type="testing", file='./data/test_data_shifted_1.6.csv', target_column_name='Outcome') + +Step 2: Configuring the Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Next, utilize the ``ModelFactory`` to load a pre-trained model, and set it as the base model for the experiment. Alternatively, you can train your own model and use it. + +.. code-block:: python + + from MED3pa.models import BaseModelManager, ModelFactory + + # Initialize the model factory and load the pre-trained model + factory = ModelFactory() + model = factory.create_model_from_pickled("./path_to_model.pkl") + + # Set the base model using BaseModelManager + base_model_manager = BaseModelManager() + base_model_manager.set_base_model(model=model) + +Step 3: Running the med3pa Experiment +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Execute the MED3PA experiment with the specified datasets and base model. You can also specify other parameters as needed. See the documentation of the subpackage for more information about the parameters. + +The experiment outputs two structure one for the reference set and the other for the testing set, both containing files indicating the extracted profiles at different declaration rates, the performance of the model on these profiles..etc. + +.. code-block:: python + + from MED3pa.med3pa import Med3paExperiment + from MED3pa.med3pa.uncertainty import AbsoluteError + from MED3pa.models.concrete_regressors import RandomForestRegressorModel + + # Define parameters for the experiment + ipc_params = {'n_estimators': 100} + apc_params = {'max_depth': 3} + med3pa_metrics = ['Auc', 'Accuracy', 'BalancedAccuracy'] + + # Execute the MED3PA experiment + ipc_params = {'n_estimators': 100} + apc_params = {'max_depth': 3} + med3pa_metrics = ['Auc', 'Accuracy', 'BalancedAccuracy'] + + # Execute the MED3PA experiment + results = Med3paExperiment.run( + datasets_manager=datasets, + base_model_manager=base_model_manager, + uncertainty_metric="absolute_error", + ipc_type='RandomForestRegressor', + ipc_params=ipc_params, + apc_params=apc_params, + samples_ratio_min=0, + samples_ratio_max=10, + samples_ratio_step=5, + med3pa_metrics=med3pa_metrics, + evaluate_models=True, + models_metrics=['MSE', 'RMSE'] + ) + + results2 = Med3paExperiment.run( + datasets_manager=datasets2, + base_model_manager=base_model_manager, + uncertainty_metric="absolute_error", + ipc_type='RandomForestRegressor', + ipc_params=ipc_params, + apc_params=apc_params, + samples_ratio_min=0, + samples_ratio_max=10, + samples_ratio_step=5, + med3pa_metrics=med3pa_metrics, + evaluate_models=True, + models_metrics=['MSE', 'RMSE'] + ) + +Step 4: Analyzing and Saving the Results +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +After running the experiment, you can analyze and save the results using the returned ``Med3paResults`` instance. + +.. code-block:: python + + # Save the results to a specified directory + results.save(file_path='./med3pa_experiment_results/') + results2.save(file_path='./med3pa_experiment_results_2') + +Additonnally, you can save the instances the IPC and APC models as pickled files: + +.. code-block:: python + + results.save_models(file_path='./med3pa_experiment_results_models') + +Step 5: Running experiments from pretrained models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +If you don't want to re-train new APC and IPC models in your experiment, you can directly use a previously saved instances. as follows: + +.. code-block:: python + + from MED3pa.med3pa import Med3paExperiment + from MED3pa.med3pa.uncertainty import AbsoluteError + + # Define parameters for the experiment + ipc_params = {'n_estimators': 100} + apc_params = {'max_depth': 3} + med3pa_metrics = ['Auc', 'Accuracy', 'BalancedAccuracy'] + + # Execute the MED3PA experiment + results = Med3paExperiment.run( + datasets_manager=datasets, + base_model_manager=base_model_manager, + uncertainty_metric="absolute_error", + ipc_type='RandomForestRegressor', + pretrained_ipc='./med3pa_experiment_results_models/ipc_model.pkl', + pretrained_apc='./med3pa_experiment_results_models/apc_model.pkl', + samples_ratio_min=0, + samples_ratio_max=10, + samples_ratio_step=5, + med3pa_metrics=med3pa_metrics, + evaluate_models=True, + models_metrics=['MSE', 'RMSE'] + ) + + results2 = Med3paExperiment.run( + datasets_manager=datasets2, + base_model_manager=base_model_manager, + uncertainty_metric="absolute_error", + ipc_type='RandomForestRegressor', + pretrained_ipc='./med3pa_experiment_results_models/ipc_model.pkl', + pretrained_apc='./med3pa_experiment_results_models/apc_model.pkl', + samples_ratio_min=0, + samples_ratio_max=10, + samples_ratio_step=5, + med3pa_metrics=med3pa_metrics, + evaluate_models=True, + models_metrics=['MSE', 'RMSE'] + ) + + # Save the results to a specified directory + results.save(file_path='./med3pa_experiment_results_pretrained') + results2.save(file_path='./med3pa_experiment_results_2_pretrained') + + +.. code-block:: python + + from MED3pa.med3pa.comparaison import Med3paComparison + + comparaison = Med3paComparison('./med3pa_experiment_results_pretrained', './med3pa_experiment_results_2_pretrained') + comparaison.compare_experiments() + comparaison.save('./med3pa_comparaison_results') + diff --git a/docs/models_tutorials.rst b/docs/models_tutorials.rst new file mode 100644 index 0000000..4860484 --- /dev/null +++ b/docs/models_tutorials.rst @@ -0,0 +1,222 @@ +Working with the Models Subpackage +---------------------------------- + +The ``models`` subpackage is crafted to offer a comprehensive suite of tools for creating and managing various machine learning models within the ``MED3pa`` package. + +Using the ModelFactory Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ``ModelFactory`` class within the ``models`` subpackage offers a streamlined approach to creating machine learning models, either from predefined configurations or from serialized states. Here’s how to leverage this functionality effectively: + +Step 1: Importing Necessary Modules +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Start by importing the required classes and utilities for model management: + +.. code-block:: python + + from pprint import pprint + from MED3pa.models import factories + +Step 2: Creating an Instance of ModelFactory +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Instantiate the ``ModelFactory``, which serves as your gateway to generating various model instances: + +.. code-block:: python + + factory = factories.ModelFactory() + +Step 3: Discovering Supported Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Before creating a model, check which models are currently supported by the factory: + +.. code-block:: python + + print("Supported models:", factory.get_supported_models()) + +**Output**: + +.. code-block:: none + + Supported models: ['XGBoostModel'] + +With this knowledge, we can proceed to create a model with specific hyperparameters. + +Step 4: Specifying and Creating a Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Define hyperparameters for an XGBoost model and use these to instantiate a model: + +.. code-block:: python + + xgb_params = { + 'objective': 'binary:logistic', + 'eval_metric': 'auc', + 'eta': 0.1, + 'max_depth': 6, + 'subsample': 0.8, + 'colsample_bytree': 0.8, + 'min_child_weight': 1, + 'nthread': 4, + 'tree_method': 'hist', + 'device': 'cpu' + } + + xgb_model = factory.create_model_with_hyperparams('XGBoostModel', xgb_params) + +Now, let’s inspect the model's configuration: + +.. code-block:: python + + pprint(xgb_model.get_info()) + +**Output**: + +.. code-block:: none + + {'data_preparation_strategy': 'ToDmatrixStrategy', + 'model': 'XGBoostModel', + 'model_type': 'Booster', + 'params': {'colsample_bytree': 0.8, + 'device': 'cpu', + 'eta': 0.1, + 'eval_metric': 'auc', + 'max_depth': 6, + 'min_child_weight': 1, + 'nthread': 4, + 'objective': 'binary:logistic', + 'subsample': 0.8, + 'tree_method': 'hist'}, + 'pickled_model': False} + +This gives us general information about the model, such as its ``data_preparation_strategy``, indicating that the input data for training, prediction, and evaluation will be transformed to ``Dmatrix`` to better suit the ``xgb.Booster`` model. It also retrieves the model's parameters, the underlying model instance class (``Booster`` in this case), and the wrapper class (``XGBoostModel`` in this case). Finally, it indicates whether this model has been created from a pickled file. + +Step 5: Loading a Model from a Serialized State +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +For pre-trained models, we can make use of the ``create_model_from_pickled`` method to load a model from its serialized (pickled) state. You only need to specify the path to this pickled file. This function will examine the pickled file and extract all necessary information. + +.. code-block:: python + + xgb_model_pkl = factory.create_model_from_pickled('path_to_model.pkl') + pprint(xgb_model_pkl.get_info()) + +**Output**: + +.. code-block:: none + + {'data_preparation_strategy': 'ToDmatrixStrategy', + 'model': 'XGBoostModel', + 'model_type': 'Booster', + 'params': {'alpha': 0, + 'base_score': 0.5, + 'boost_from_average': 1, + 'booster': 'gbtree', + 'cache_opt': 1, + ... + 'updater': 'grow_quantile_histmaker', + 'updater_seq': 'grow_quantile_histmaker', + 'validate_parameters': 0}, + 'pickled_model': True} + +Using the Model Class +~~~~~~~~~~~~~~~~~~~~~ +In this section, we will learn how to train, predict, and evaluate a machine learning model. For this, we will directly use the created model from the previous section. + +Step 1: Training the Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Generate Training and Validation Data: + +Prepare the data for training and validation. The following example generates synthetic data for demonstration purposes: + +.. code-block:: python + + np.random.seed(0) + X_train = np.random.randn(1000, 10) + y_train = np.random.randint(0, 2, 1000) + X_val = np.random.randn(1000, 10) + y_val = np.random.randint(0, 2, 1000) + +Training the Model: + +When training a model, you can specify additional ``training_parameters``. If they are not specified, the model will use the initialization parameters. You can also specify whether you'd like to balance the training classes. + +.. code-block:: python + + training_params = { + 'eval_metric': 'logloss', + 'eta': 0.1, + 'max_depth': 6 + } + xgb_model.train(X_train, y_train, X_val, y_val, training_params, balance_train_classes=True) + +This process optimizes the model based on the specified hyperparameters and validation data to prevent overfitting. + +Step 2: Predicting Using the Trained Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Model Prediction: + +Once the model is trained, use it to predict labels or probabilities on a new dataset. This step demonstrates predicting binary labels for the test data. The ``return_proba`` parameter specifies whether to return the ``predicted_probabilities`` or the ``predicted_labels``. The labels are calculated based on the ``threshold``. + +.. code-block:: python + + X_test = np.random.randn(1000, 10) + y_test = np.random.randint(0, 2, 1000) + y_pred = xgb_model.predict(X_test, return_proba=False, threshold=0.5) + +Step 3: Evaluating the Model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Evaluate the model's performance using various metrics to understand its effectiveness in making predictions. The supported metrics include Accuracy, AUC, Precision, Recall, and F1 Score, among others. The ``evaluate`` method will handle the model predictions and then evaluate the model based on these predictions. You only need to specify the test data. + +To retrieve the list of supported ``classification_metrics``, you can use ``ClassificationEvaluationMetrics.supported_metrics()``: + +.. code-block:: python + + from MED3pa.models import ClassificationEvaluationMetrics + + # Display supported metrics + print("Supported evaluation metrics:", ClassificationEvaluationMetrics.supported_metrics()) + + # Evaluate the model + evaluation_results = xgb_model.evaluate(X_test, y_test, eval_metrics=['Auc', 'Accuracy'], print_results=True) + +**Output**: + +.. code-block:: none + + Supported evaluation metrics: ['Accuracy', 'BalancedAccuracy', 'Precision', 'Recall', 'F1Score', 'Specificity', 'Sensitivity', 'Auc', 'LogLoss', 'Auprc', 'NPV', 'PPV', 'MCC'] + Evaluation Results: + Auc: 0.51 + Accuracy: 0.50 + +Step 4: Retrieving Model Information +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The ``get_info`` method provides detailed information about the model, including its type, parameters, data preparation strategy, and whether it's a pickled model. This is useful for understanding the configuration and state of the model. + +.. code-block:: python + + model_info = xgb_model.get_info() + pprint(model_info) + +**Output**: + +.. code-block:: none + + {'model': 'XGBoostModel', + 'model_type': 'Booster', + 'params': {'objective': 'binary:logistic', + 'eval_metric': 'auc', + 'eta': 0.1, + 'max_depth': 6, + 'subsample': 0.8, + 'colsample_bytree': 0.8, + 'min_child_weight': 1, + 'nthread': 4, + 'tree_method': 'hist', + 'device': 'cpu'}, + 'data_preparation_strategy': 'ToDmatrixStrategy', + 'pickled_model': False} + +Step 5: Saving Model Information +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +You can save the model by using the `save` method, which will save the underlying model instance as a pickled file, and the model's information as a .json file: + +.. code-block:: none + + xgb_model.save("./models/saved_model") \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..cbf1e36 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +sphinx +sphinx-rtd-theme diff --git a/docs/tutorials.rst b/docs/tutorials.rst new file mode 100644 index 0000000..765167b --- /dev/null +++ b/docs/tutorials.rst @@ -0,0 +1,7 @@ +.. toctree:: + :maxdepth: 4 + :caption: Tutorials + + datasets_tutorials + models_tutorials + med3pa_tutorials \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..e97eb00 Binary files /dev/null and b/environment.yaml differ diff --git a/examples/oym_example.ipynb b/examples/oym_example.ipynb new file mode 100644 index 0000000..0113247 --- /dev/null +++ b/examples/oym_example.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# MED3pa utilisation example\n", + "\n", + "This tutorial guides you through the process of setting up and running comprehensive experiments using the `MED3pa` subpackage. It includes steps to execute MED3pa experiment with `Med3paExperiment`." + ], + "metadata": { + "collapsed": false + }, + "id": "b28b6f7deaed4364" + }, + { + "cell_type": "markdown", + "source": [ + "## Get Data and BaseModel." + ], + "metadata": { + "collapsed": false + }, + "id": "3018d55957753b12" + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import wget\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from MED3pa.datasets import DatasetsManager\n", + "from MED3pa.med3pa import Med3paExperiment\n", + "from MED3pa.models import BaseModelManager\n", + "from MED3pa.visualization.mdr_visualization import visualize_mdr\n", + "from MED3pa.visualization.profiles_visualization import visualize_tree" + ], + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-08-26T19:13:03.098273Z", + "start_time": "2025-08-26T19:13:03.092243100Z" + } + }, + "id": "initial_id", + "execution_count": 22 + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "def oym_process(df):\n", + " # One hot encoding of categorical variables\n", + " categorical_variables = ['living_status', 'admission_group', 'service_group']\n", + " df = pd.get_dummies(df, columns=categorical_variables)\n", + " df = pd.get_dummies(df, columns=['gender'], drop_first=True)\n", + "\n", + " # Convert boolean to int\n", + " boolean_variables = ['CSO', 'oym']\n", + " df[boolean_variables] = df[boolean_variables].astype(int)\n", + " \n", + " # Extract target and features\n", + " x_features = df.drop(columns=['oym'])\n", + " y = df['oym'].to_numpy()\n", + " return x_features, y" + ], + "metadata": { + "collapsed": false + }, + "id": "eda88d93fe2d932f", + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "# Get Data\n", + "os.makedirs(\"data/datasets/oym\", exist_ok=True)\n", + "wget.download('https://zenodo.org/records/12954673/files/dataset.csv?download=1',\n", + " out=os.getcwd() + '\\\\data\\\\datasets\\\\oym\\\\dataset.csv')\n", + "df = pd.read_csv(os.getcwd() + '\\\\data\\\\datasets\\\\oym\\\\dataset.csv')\n", + "\n", + "# Process Data\n", + "x, y = oym_process(df)\n", + "x_train, x_evaluation, y_train, y_evaluation = train_test_split(x, y, test_size=0.3, random_state=54288)\n", + "\n", + "# Train the BaseModel\n", + "clf = RandomForestClassifier(max_depth=4, random_state=42).fit(x_train, y_train)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2025-08-26T19:16:13.549526800Z", + "start_time": "2025-08-26T19:15:49.958593500Z" + } + }, + "id": "875a6ed7a0038e96", + "execution_count": 27 + }, + { + "cell_type": "markdown", + "source": [ + "## Define MED3pa experiment characteristics" + ], + "metadata": { + "collapsed": false + }, + "id": "34902be792aa9b46" + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "med3pa_params = {\"uncertainty_metric\":\"sigmoidal_error\",\n", + " \"ipc_type\":'RandomForestRegressor',\n", + " \"ipc_params\":{'n_estimators': 100},\n", + " \"apc_params\":{'max_depth': 6},\n", + " \"ipc_grid_params\":{'n_estimators': [50, 100, 200],\n", + " 'max_depth': [2, 4, 6]},\n", + " \"apc_grid_params\":{'min_samples_leaf': [2, 4, 6]},\n", + " \"samples_ratio_min\":0,\n", + " \"samples_ratio_max\":10,\n", + " \"samples_ratio_step\":5,\n", + " \"evaluate_models\":True}" + ], + "metadata": { + "collapsed": false + }, + "id": "15b499dfbee11fe4", + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "## MED3pa evaluation of BaseModel" + ], + "metadata": { + "collapsed": false + }, + "id": "cce8017f5093c365" + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "# Initialize the DatasetsManager\n", + "datasets = DatasetsManager()\n", + "datasets.set_from_data(dataset_type=\"testing\",\n", + " observations=x_evaluation.to_numpy(),\n", + " true_labels=y_evaluation,\n", + " column_labels=x_evaluation.columns)\n", + "# Initialize the BaseModelManager\n", + "base_model_manager = BaseModelManager(model=clf)\n", + "\n", + "# Execute the MED3PA experiment\n", + "results = Med3paExperiment.run(\n", + " datasets_manager=datasets,\n", + " base_model_manager=base_model_manager,\n", + " **med3pa_params\n", + ")\n", + "\n", + "# Save the results to a specified directory\n", + "results.save(file_path='results/oym')\n", + "\n", + "# Visualize results\n", + "visualize_mdr(result=results, filename='results/oym/mdr')\n", + "visualize_tree(result=results, filename='results/oym/profiles')" + ], + "metadata": { + "collapsed": false + }, + "id": "22486082f047c6b1", + "execution_count": null + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "6910720dabd1c2ae" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/readthedocs.yml b/readthedocs.yml new file mode 100644 index 0000000..286b3c9 --- /dev/null +++ b/readthedocs.yml @@ -0,0 +1,17 @@ +version: 2 + +build: + os: ubuntu-20.04 + tools: + python: "3.12" + +sphinx: + configuration: docs/conf.py + +formats: + - pdf + - epub + +python: + install: + - requirements: docs/requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e9265e0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +checkpointer==2.1.0 +numpy>=1.21.0, <2.1.0 +pandas>=1.4.0, <3.0.0 +PyYAML>=5.4, <7.0 +matplotlib==3.9.2 +ray>=2.37.0 +relib==1.2.0 +scikit_learn>=1.0, <2.0 +scipy>=1.7, <2.0 +setuptools>=49.6.0, <70.0 +torch>=1.9.0, <3.0 +tqdm>=4.50.0, <5.0 +xgboost==2.1.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..71eea52 --- /dev/null +++ b/setup.py @@ -0,0 +1,26 @@ +import os +from setuptools import setup, find_packages + +with open("README.md", encoding='utf-8') as f: + long_description = f.read() + +with open('requirements.txt') as f: + requirements = f.readlines() + +setup( + name="MED3pa", + version="1.0.0", + author="MEDomics consortium", + author_email="medomics.info@gmail.com", + description="Python Open-source package for ensuring robust and reliable ML models deployments", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/MEDomics-UdeS/MED3pa", + project_urls={ + 'Documentation': 'https://med3pa.readthedocs.io/en/latest/', + 'Github': 'https://github.com/MEDomics-UdeS/MED3pa' + }, + packages=find_packages(exclude=['docs', 'tests', 'experiments']), + python_requires='>=3.9', + install_requires=requirements, +)