Skip to content

Commit c80eaca

Browse files
committed
Update Dreamstudio integration to fix bad image generation.
1 parent 872f5a4 commit c80eaca

7 files changed

Lines changed: 568 additions & 43 deletions

File tree

BlazorDiffusion/BlazorDiffusion.csproj

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@
5151
</ItemGroup>
5252

5353
<ItemGroup>
54-
<Protobuf Include="proto\generation.proto" GrpcServices="Client" />
54+
<Protobuf Include="proto\dashboard.proto" GrpcServices="None" />
55+
<Protobuf Include="proto\engines.proto" GrpcServices="None" />
56+
<Protobuf Include="proto\generation.proto" GrpcServices="Client" ProtoRoot="proto\" />
57+
<Protobuf Include="proto\project.proto" GrpcServices="None" />
58+
<Protobuf Include="proto\tensors.proto" GrpcServices="None" />
5559
</ItemGroup>
5660

5761
<Target Name="CreateAppDataFolderBuild" AfterTargets="AfterBuild">

BlazorDiffusion/StableDiffusionClient.cs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class DreamStudioClient : IStableDiffusionClient
1919
public string EngineId { get; set; }
2020
public string? PublicPrefix { get; set; }
2121
public IVirtualFiles VirtualFiles { get; set; }
22-
22+
2323
public DreamStudioClient()
2424
{
2525
var credentials = CallCredentials.FromInterceptor((context, metadata) =>
@@ -28,6 +28,7 @@ public DreamStudioClient()
2828
{
2929
metadata.Add("Authorization", $"Bearer {ApiKey}");
3030
}
31+
3132
return Task.CompletedTask;
3233
});
3334
channel = GrpcChannel.ForAddress("https://grpc.stability.ai", new GrpcChannelOptions
@@ -36,35 +37,45 @@ public DreamStudioClient()
3637
});
3738
client = new GenerationService.GenerationServiceClient(channel);
3839
}
40+
3941
public async Task<ImageGenerationResponse> GenerateImageAsync(ImageGeneration request)
4042
{
41-
4243
var generateRequest = new Request
4344
{
44-
EngineId = string.IsNullOrEmpty(EngineId) ? DefaultEngineId : EngineId,
45+
EngineId = DefaultEngineId,
4546
RequestId = Guid.NewGuid().ToString("D"),
47+
RequestedType = ArtifactType.ArtifactImage,
4648
Image = new ImageParameters
4749
{
4850
Height = Convert.ToUInt32(request.Height),
4951
Width = Convert.ToUInt32(request.Width),
5052
Seed = { Convert.ToUInt32(request.Seed) },
5153
Steps = Convert.ToUInt32(request.Steps),
5254
Samples = Convert.ToUInt32(request.Images),
53-
Transform = new TransformType
55+
// Transform = new TransformType
56+
// {
57+
// Diffusion = DiffusionSampler.SamplerKLms
58+
// },
59+
Parameters =
5460
{
55-
Diffusion = DiffusionSampler.SamplerKLms
61+
new StepParameter
62+
{
63+
Guidance = new GuidanceParameters
64+
{
65+
GuidancePreset = GuidancePreset.Simple
66+
},
67+
Sampler = new SamplerParameters
68+
{
69+
CfgScale = 7.0f
70+
}
71+
}
5672
}
5773
},
5874
Prompt =
5975
{
6076
new Prompt()
6177
{
62-
Text = request.Prompt,
63-
Parameters = new PromptParameters
64-
{
65-
Init = false,
66-
Weight = 0.0f
67-
}
78+
Text = request.Prompt
6879
},
6980
},
7081
};
@@ -99,6 +110,7 @@ public async Task<ImageGenerationResponse> GenerateImageAsync(ImageGeneration re
99110
});
100111
}
101112
}
113+
102114
return new ImageGenerationResponse
103115
{
104116
RequestId = generateRequest.RequestId,
@@ -115,13 +127,13 @@ public async Task SaveMetadataAsync(Creative creative)
115127
{
116128
var vfsPathSuffix = creative.Key;
117129
var outputDir = Path.Join(OutputPathPrefix, vfsPathSuffix);
118-
await VirtualFiles.WriteFileAsync(Path.Join(outputDir,"metadata.json"), creative.ToJson().IndentJson());
130+
await VirtualFiles.WriteFileAsync(Path.Join(outputDir, "metadata.json"), creative.ToJson().IndentJson());
119131
}
120132

121133
public Task DeleteFolderAsync(Creative creative)
122134
{
123135
var vfsPathSuffix = creative.Key;
124-
var directory = VirtualFiles.GetDirectory(Path.Join(OutputPathPrefix,vfsPathSuffix));
136+
var directory = VirtualFiles.GetDirectory(Path.Join(OutputPathPrefix, vfsPathSuffix));
125137
var allFiles = directory.GetAllFiles();
126138
VirtualFiles.DeleteFiles(allFiles);
127139
return Task.CompletedTask;
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
syntax = 'proto3';
2+
package gooseai;
3+
option go_package = "./;dashboard";
4+
5+
enum OrganizationRole {
6+
MEMBER = 0;
7+
ACCOUNTANT = 1;
8+
OWNER = 2;
9+
}
10+
11+
message OrganizationMember {
12+
Organization organization = 1;
13+
optional User user = 2;
14+
OrganizationRole role = 3;
15+
bool is_default = 4;
16+
}
17+
18+
message OrganizationGrant {
19+
double amount_granted = 1;
20+
double amount_used = 2;
21+
uint64 expires_at = 3;
22+
uint64 granted_at = 4;
23+
}
24+
25+
message OrganizationPaymentInfo {
26+
double balance = 1;
27+
repeated OrganizationGrant grants = 2;
28+
}
29+
30+
message OrganizationAutoCharge {
31+
bool enabled = 1;
32+
string id = 2;
33+
uint64 created_at = 3;
34+
}
35+
36+
message Organization {
37+
string id = 1;
38+
string name = 2;
39+
string description = 3;
40+
repeated OrganizationMember members = 4;
41+
optional OrganizationPaymentInfo payment_info = 5;
42+
optional string stripe_customer_id = 6;
43+
optional OrganizationAutoCharge auto_charge = 7;
44+
}
45+
46+
message APIKey {
47+
string key = 1;
48+
bool is_secret = 2;
49+
uint64 created_at = 3;
50+
}
51+
52+
message User {
53+
string id = 1;
54+
optional string auth_id = 2;
55+
string profile_picture = 3;
56+
string email = 4;
57+
repeated OrganizationMember organizations = 5;
58+
repeated APIKey api_keys = 7;
59+
uint64 created_at = 8;
60+
optional bool email_verified = 9;
61+
}
62+
63+
message CostData {
64+
uint32 amount_tokens = 1;
65+
double amount_credits = 2;
66+
}
67+
68+
message UsageMetric {
69+
string operation = 1;
70+
string engine = 2;
71+
CostData input_cost = 3;
72+
CostData output_cost = 4;
73+
optional string user = 5;
74+
uint64 aggregation_timestamp = 6;
75+
}
76+
77+
message CostTotal {
78+
uint32 amount_tokens = 1;
79+
double amount_credits = 2;
80+
}
81+
82+
message TotalMetricsData {
83+
CostTotal input_total = 1;
84+
CostTotal output_total = 2;
85+
}
86+
87+
message Metrics {
88+
repeated UsageMetric metrics = 1;
89+
TotalMetricsData total = 2;
90+
}
91+
92+
message EmptyRequest {}
93+
94+
message GetOrganizationRequest {
95+
string id = 1;
96+
}
97+
98+
message GetMetricsRequest {
99+
string organization_id = 1;
100+
optional string user_id = 2;
101+
uint64 range_from = 3;
102+
uint64 range_to = 4;
103+
bool include_per_request_metrics = 5;
104+
}
105+
106+
message APIKeyRequest {
107+
bool is_secret = 1;
108+
}
109+
110+
message APIKeyFindRequest {
111+
string id = 1;
112+
}
113+
114+
message UpdateDefaultOrganizationRequest {
115+
string organization_id = 1;
116+
}
117+
118+
message ClientSettings {
119+
bytes settings = 1;
120+
}
121+
122+
message CreateAutoChargeIntentRequest {
123+
string organization_id = 1;
124+
uint64 monthly_maximum = 2;
125+
uint64 minimum_value = 3;
126+
uint64 amount_credits = 4;
127+
}
128+
129+
message CreateChargeRequest {
130+
uint64 amount = 1;
131+
string organization_id = 2;
132+
}
133+
134+
message GetChargesRequest {
135+
string organization_id = 1;
136+
uint64 range_from = 2;
137+
uint64 range_to = 3;
138+
}
139+
140+
message Charge {
141+
string id = 1;
142+
bool paid = 2;
143+
string receipt_link = 3;
144+
string payment_link = 4;
145+
uint64 created_at = 5;
146+
uint64 amount_credits = 6;
147+
}
148+
149+
message Charges {
150+
repeated Charge charges = 1;
151+
}
152+
153+
message GetAutoChargeRequest {
154+
string organization_id = 1;
155+
}
156+
157+
message AutoChargeIntent {
158+
string id = 1;
159+
string payment_link = 2;
160+
uint64 created_at = 3;
161+
uint64 monthly_maximum = 4;
162+
uint64 minimum_value = 5;
163+
uint64 amount_credits = 6;
164+
}
165+
166+
message UpdateUserInfoRequest {
167+
optional string email = 1;
168+
}
169+
170+
message UserPasswordChangeTicket {
171+
string ticket = 1;
172+
}
173+
174+
service DashboardService {
175+
// Get info
176+
rpc GetMe (EmptyRequest) returns (User);
177+
rpc GetOrganization (GetOrganizationRequest) returns (Organization);
178+
rpc GetMetrics (GetMetricsRequest) returns (Metrics);
179+
180+
// API key management
181+
rpc CreateAPIKey (APIKeyRequest) returns (APIKey);
182+
rpc DeleteAPIKey (APIKeyFindRequest) returns (APIKey);
183+
184+
// User settings
185+
rpc UpdateDefaultOrganization (UpdateDefaultOrganizationRequest) returns (User);
186+
rpc GetClientSettings (EmptyRequest) returns (ClientSettings);
187+
rpc SetClientSettings (ClientSettings) returns (ClientSettings);
188+
rpc UpdateUserInfo (UpdateUserInfoRequest) returns (User);
189+
rpc CreatePasswordChangeTicket (EmptyRequest) returns (UserPasswordChangeTicket);
190+
rpc DeleteAccount (EmptyRequest) returns (User);
191+
192+
// Payment functions
193+
rpc CreateCharge (CreateChargeRequest) returns (Charge);
194+
rpc GetCharges (GetChargesRequest) returns (Charges);
195+
rpc CreateAutoChargeIntent (CreateAutoChargeIntentRequest) returns (AutoChargeIntent);
196+
rpc UpdateAutoChargeIntent (CreateAutoChargeIntentRequest) returns (AutoChargeIntent);
197+
rpc GetAutoChargeIntent (GetAutoChargeRequest) returns (AutoChargeIntent);
198+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
syntax = 'proto3';
2+
package gooseai;
3+
option go_package = "./;engines";
4+
5+
// Possible engine type
6+
enum EngineType {
7+
TEXT = 0;
8+
PICTURE = 1;
9+
AUDIO = 2;
10+
VIDEO = 3;
11+
CLASSIFICATION = 4;
12+
STORAGE = 5;
13+
}
14+
15+
enum EngineTokenizer {
16+
GPT2 = 0;
17+
PILE = 1;
18+
}
19+
20+
// Engine info struct
21+
message EngineInfo {
22+
string id = 1;
23+
string owner = 2;
24+
bool ready = 3;
25+
EngineType type = 4;
26+
EngineTokenizer tokenizer = 5;
27+
string name = 6;
28+
string description = 7;
29+
}
30+
31+
message ListEnginesRequest {
32+
// Empty
33+
}
34+
35+
// Engine info list
36+
message Engines {
37+
repeated EngineInfo engine = 1;
38+
}
39+
40+
service EnginesService {
41+
rpc ListEngines (ListEnginesRequest) returns (Engines) {};
42+
}

0 commit comments

Comments
 (0)