|
| 1 | +#include <stdio.h> |
| 2 | +#include <string.h> |
| 3 | +#include <time.h> |
| 4 | +#include <iostream> |
| 5 | +#include <random> |
| 6 | +#include <string> |
| 7 | +#include <vector> |
| 8 | +#include "gosd.h" |
| 9 | + |
| 10 | +// #include "preprocessing.hpp" |
| 11 | +#include "flux.hpp" |
| 12 | +#include "stable-diffusion.h" |
| 13 | + |
| 14 | +#define STB_IMAGE_IMPLEMENTATION |
| 15 | +#define STB_IMAGE_STATIC |
| 16 | +#include "stb_image.h" |
| 17 | + |
| 18 | +#define STB_IMAGE_WRITE_IMPLEMENTATION |
| 19 | +#define STB_IMAGE_WRITE_STATIC |
| 20 | +#include "stb_image_write.h" |
| 21 | + |
| 22 | +#define STB_IMAGE_RESIZE_IMPLEMENTATION |
| 23 | +#define STB_IMAGE_RESIZE_STATIC |
| 24 | +#include "stb_image_resize.h" |
| 25 | + |
// Sampler names accepted by the "sampler" option. The position of each entry
// is used as the numeric sample method value, so the order must stay in sync
// with enum sample_method in stable-diffusion.h.
const char* sample_method_str[] = {
    "euler_a",    // 0
    "euler",      // 1
    "heun",       // 2
    "dpm2",       // 3
    "dpm++2s_a",  // 4
    "dpm++2m",    // 5
    "dpm++2mv2",  // 6
    "ipndm",      // 7
    "ipndm_v",    // 8
    "lcm",        // 9
};
| 39 | + |
// Scheduler names accepted by the "scheduler" option. The position of each
// entry is used as the numeric schedule value, so the order must stay in sync
// with the sample_schedule enum in stable-diffusion.h.
const char* schedule_str[] = {
    "default",      // 0
    "discrete",     // 1
    "karras",       // 2
    "exponential",  // 3
    "ays",          // 4
    "gits",         // 5
};
| 49 | + |
| 50 | +sd_ctx_t* sd_c; |
| 51 | + |
| 52 | +sample_method_t sample_method; |
| 53 | + |
| 54 | +int load_model(char *model, char* options[], int threads, int diff) { |
| 55 | + fprintf (stderr, "Loading model!\n"); |
| 56 | + |
| 57 | + char *stableDiffusionModel = ""; |
| 58 | + if (diff == 1 ) { |
| 59 | + stableDiffusionModel = model; |
| 60 | + model = ""; |
| 61 | + } |
| 62 | + |
| 63 | + // decode options. Options are in form optname:optvale, or if booleans only optname. |
| 64 | + char *clip_l_path = ""; |
| 65 | + char *clip_g_path = ""; |
| 66 | + char *t5xxl_path = ""; |
| 67 | + char *vae_path = ""; |
| 68 | + char *scheduler = ""; |
| 69 | + char *sampler = ""; |
| 70 | + |
| 71 | + // If options is not NULL, parse options |
| 72 | + for (int i = 0; options[i] != NULL; i++) { |
| 73 | + char *optname = strtok(options[i], ":"); |
| 74 | + char *optval = strtok(NULL, ":"); |
| 75 | + if (optval == NULL) { |
| 76 | + optval = "true"; |
| 77 | + } |
| 78 | + |
| 79 | + if (!strcmp(optname, "clip_l_path")) { |
| 80 | + clip_l_path = optval; |
| 81 | + } |
| 82 | + if (!strcmp(optname, "clip_g_path")) { |
| 83 | + clip_g_path = optval; |
| 84 | + } |
| 85 | + if (!strcmp(optname, "t5xxl_path")) { |
| 86 | + t5xxl_path = optval; |
| 87 | + } |
| 88 | + if (!strcmp(optname, "vae_path")) { |
| 89 | + vae_path = optval; |
| 90 | + } |
| 91 | + if (!strcmp(optname, "scheduler")) { |
| 92 | + scheduler = optval; |
| 93 | + } |
| 94 | + if (!strcmp(optname, "sampler")) { |
| 95 | + sampler = optval; |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + int sample_method_found = -1; |
| 100 | + for (int m = 0; m < N_SAMPLE_METHODS; m++) { |
| 101 | + if (!strcmp(sampler, sample_method_str[m])) { |
| 102 | + sample_method_found = m; |
| 103 | + } |
| 104 | + } |
| 105 | + if (sample_method_found == -1) { |
| 106 | + fprintf(stderr, "Invalid sample method, default to EULER_A!\n"); |
| 107 | + sample_method_found = EULER_A; |
| 108 | + } |
| 109 | + sample_method = (sample_method_t)sample_method_found; |
| 110 | + |
| 111 | + int schedule_found = -1; |
| 112 | + for (int d = 0; d < N_SCHEDULES; d++) { |
| 113 | + if (!strcmp(scheduler, schedule_str[d])) { |
| 114 | + schedule_found = d; |
| 115 | + fprintf (stderr, "Found scheduler: %s\n", scheduler); |
| 116 | + |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + if (schedule_found == -1) { |
| 121 | + fprintf (stderr, "Invalid scheduler! using DEFAULT\n"); |
| 122 | + schedule_found = DEFAULT; |
| 123 | + } |
| 124 | + |
| 125 | + schedule_t schedule = (schedule_t)schedule_found; |
| 126 | + |
| 127 | + fprintf (stderr, "Creating context\n"); |
| 128 | + sd_ctx_t* sd_ctx = new_sd_ctx(model, |
| 129 | + clip_l_path, |
| 130 | + clip_g_path, |
| 131 | + t5xxl_path, |
| 132 | + stableDiffusionModel, |
| 133 | + vae_path, |
| 134 | + "", |
| 135 | + "", |
| 136 | + "", |
| 137 | + "", |
| 138 | + "", |
| 139 | + false, |
| 140 | + false, |
| 141 | + false, |
| 142 | + threads, |
| 143 | + SD_TYPE_COUNT, |
| 144 | + STD_DEFAULT_RNG, |
| 145 | + schedule, |
| 146 | + false, |
| 147 | + false, |
| 148 | + false, |
| 149 | + false); |
| 150 | + |
| 151 | + if (sd_ctx == NULL) { |
| 152 | + fprintf (stderr, "failed loading model (generic error)\n"); |
| 153 | + return 1; |
| 154 | + } |
| 155 | + fprintf (stderr, "Created context: OK\n"); |
| 156 | + |
| 157 | + sd_c = sd_ctx; |
| 158 | + |
| 159 | + return 0; |
| 160 | +} |
| 161 | + |
| 162 | +int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale) { |
| 163 | + |
| 164 | + sd_image_t* results; |
| 165 | + |
| 166 | + std::vector<int> skip_layers = {7, 8, 9}; |
| 167 | + |
| 168 | + fprintf (stderr, "Generating image\n"); |
| 169 | + |
| 170 | + results = txt2img(sd_c, |
| 171 | + text, |
| 172 | + negativeText, |
| 173 | + -1, //clip_skip |
| 174 | + cfg_scale, // sfg_scale |
| 175 | + 3.5f, |
| 176 | + width, |
| 177 | + height, |
| 178 | + sample_method, |
| 179 | + steps, |
| 180 | + seed, |
| 181 | + 1, |
| 182 | + NULL, |
| 183 | + 0.9f, |
| 184 | + 20.f, |
| 185 | + false, |
| 186 | + "", |
| 187 | + skip_layers.data(), |
| 188 | + skip_layers.size(), |
| 189 | + 0, |
| 190 | + 0.01, |
| 191 | + 0.2); |
| 192 | + |
| 193 | + if (results == NULL) { |
| 194 | + fprintf (stderr, "NO results\n"); |
| 195 | + return 1; |
| 196 | + } |
| 197 | + |
| 198 | + if (results[0].data == NULL) { |
| 199 | + fprintf (stderr, "Results with no data\n"); |
| 200 | + return 1; |
| 201 | + } |
| 202 | + |
| 203 | + fprintf (stderr, "Writing PNG\n"); |
| 204 | + |
| 205 | + fprintf (stderr, "DST: %s\n", dst); |
| 206 | + fprintf (stderr, "Width: %d\n", results[0].width); |
| 207 | + fprintf (stderr, "Height: %d\n", results[0].height); |
| 208 | + fprintf (stderr, "Channel: %d\n", results[0].channel); |
| 209 | + fprintf (stderr, "Data: %p\n", results[0].data); |
| 210 | + |
| 211 | + stbi_write_png(dst, results[0].width, results[0].height, results[0].channel, |
| 212 | + results[0].data, 0, NULL); |
| 213 | + fprintf (stderr, "Saved resulting image to '%s'\n", dst); |
| 214 | + |
| 215 | + // TODO: free results. Why does it crash? |
| 216 | + |
| 217 | + free(results[0].data); |
| 218 | + results[0].data = NULL; |
| 219 | + free(results); |
| 220 | + fprintf (stderr, "gen_image is done", dst); |
| 221 | + |
| 222 | + return 0; |
| 223 | +} |
| 224 | + |
| 225 | +int unload() { |
| 226 | + free_sd_ctx(sd_c); |
| 227 | +} |
| 228 | + |
0 commit comments