Update training logic to compute loss over logits #990
Conversation
Hi @constantinpape, this should be good for a first review from my side! The intended idea here is simple: 1) compute the loss over the logit masks (to make this possible, I convert the labels to match the logits' dimensions) and 2) work with iterative prompts for logits. Let me know how it looks! EDIT: I am going to run a training with this and see how the results look for LIVECell!
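For illustration, here is a minimal sketch of step 1. The function name, the nearest-neighbor downsampling, and the BCE-with-logits loss are assumptions for the sketch, not necessarily the PR's actual code:

```python
import torch
import torch.nn.functional as F


def loss_over_logits(logits, labels, loss_fn=None):
    """Sketch: compute the loss directly on the low-resolution logit
    masks instead of on upsampled full-resolution predictions.

    logits: (B, 1, H_low, W_low) raw mask logits from the decoder,
            e.g. 256 x 256 for SAM's 1024 x 1024 input resolution.
    labels: (B, 1, H, W) binary ground-truth masks at input resolution.
    """
    if loss_fn is None:
        loss_fn = torch.nn.BCEWithLogitsLoss()
    # Downsample the labels to the logit resolution. Nearest-neighbor
    # interpolation keeps the labels strictly binary.
    labels_low = F.interpolate(
        labels.float(), size=logits.shape[-2:], mode="nearest"
    )
    return loss_fn(logits, labels_low)
```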
        self.log_image_interval = trainer.log_image_interval

-    def add_image(self, x, y, samples, name, step):
+    def add_image(self, x, y, name, step):
Was this never used / not used anymore?
It was used previously; I removed it because the samples are logits now. Would you recommend keeping those predictions by upsampling them?
What exactly were the samples? Examples of mask predictions? It might be good to keep them; we can discuss later.
Yes, samples were mask predictions!
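If we do keep them, upsampling the low-resolution logits back to the image resolution could look roughly like the sketch below. `logits_to_mask_samples` is a hypothetical helper; thresholding at 0 assumes raw logits, where sigmoid(0) = 0.5 (i.e. SAM's default mask threshold):

```python
import torch.nn.functional as F


def logits_to_mask_samples(logits, image_shape):
    """Sketch: upsample (B, 1, H_low, W_low) logits to the input image
    resolution so mask predictions can still be logged for inspection."""
    upsampled = F.interpolate(
        logits, size=image_shape, mode="bilinear", align_corners=False
    )
    # Threshold at 0 on the raw logits, which corresponds to a 0.5
    # cutoff on the sigmoid probabilities.
    return (upsampled > 0).float()
```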
Here's a detailed description of the PR (we can decide later when to come back to this). Core idea: computing the loss over the predicted logit masks and the downsampled ground-truth masks. The current implementation works; however, it does not bring us significant memory advantages (below are quick try-outs on LIVECell):
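For completeness, a rough sketch of step 2, iterative prompting with logits. The `predict` callable and the `prompts` object are hypothetical placeholders for the model's prompt-conditioned forward pass, not the PR's actual interface:

```python
def iterative_prompting_with_logits(predict, prompts, n_iterations=3):
    """Sketch: feed the low-res logits from one round back as the
    dense mask prompt for the next round, without upsampling."""
    mask_input = None
    logits = None
    for _ in range(n_iterations):
        # `predict` is assumed to return raw low-resolution logits
        # (e.g. 256 x 256), which SAM-style models can accept directly
        # as the mask input for the next prompting round.
        logits = predict(prompts, mask_input=mask_input)
        # Detach so gradients don't flow back through earlier rounds
        # via the mask prompt (one possible design choice).
        mask_input = logits.detach()
    return logits
```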
WIP!